1 //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"
10 
11 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
12 #include "mlir/Dialect/Vector/IR/VectorOps.h"
13 #include "mlir/IR/Dialect.h"
14 #include "mlir/IR/Operation.h"
15 
16 using namespace mlir;
17 using namespace mlir::bufferization;
18 using namespace mlir::vector;
19 
20 namespace mlir {
21 namespace vector {
22 namespace {
23 
24 /// Bufferization of vector.transfer_read. Replaced with a new
25 /// vector.transfer_read that operates on a memref.
26 struct TransferReadOpInterface
27     : public BufferizableOpInterface::ExternalModel<TransferReadOpInterface,
28                                                     vector::TransferReadOp> {
29   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
30                               const AnalysisState &state) const {
31     assert(opOperand.get().getType().isa<RankedTensorType>() &&
32            "only tensor types expected");
33     return true;
34   }
35 
36   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
37                                const AnalysisState &state) const {
38     assert(opOperand.get().getType().isa<RankedTensorType>() &&
39            "only tensor types expected");
40     return false;
41   }
42 
43   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
44                                             const AnalysisState &state) const {
45     return {};
46   }
47 
48   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
49                           const BufferizationOptions &options) const {
50     auto readOp = cast<vector::TransferReadOp>(op);
51     assert(readOp.getShapedType().isa<TensorType>() &&
52            "only tensor types expected");
53     FailureOr<Value> buffer = getBuffer(rewriter, readOp.getSource(), options);
54     if (failed(buffer))
55       return failure();
56     replaceOpWithNewBufferizedOp<vector::TransferReadOp>(
57         rewriter, readOp, readOp.getVectorType(), *buffer, readOp.getIndices(),
58         readOp.getPermutationMap(), readOp.getPadding(), readOp.getMask(),
59         readOp.getInBoundsAttr());
60     return success();
61   }
62 };
63 
64 /// Bufferization of vector.transfer_write. Replace with a new
65 /// vector.transfer_write that operates on a memref.
66 struct TransferWriteOpInterface
67     : public BufferizableOpInterface::ExternalModel<TransferWriteOpInterface,
68                                                     vector::TransferWriteOp> {
69   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
70                               const AnalysisState &state) const {
71     assert(opOperand.get().getType().isa<TensorType>() &&
72            "only tensor types expected");
73     return true;
74   }
75 
76   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
77                                const AnalysisState &state) const {
78     assert(opOperand.get().getType().isa<TensorType>() &&
79            "only tensor types expected");
80     return true;
81   }
82 
83   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
84                                             const AnalysisState &state) const {
85     assert(opOperand.get().getType().isa<TensorType>() &&
86            "only tensor types expected");
87     return {op->getOpResult(0)};
88   }
89 
90   BufferRelation bufferRelation(Operation *op, OpResult opResult,
91                                 const AnalysisState &state) const {
92     return BufferRelation::Equivalent;
93   }
94 
95   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
96                           const BufferizationOptions &options) const {
97     auto writeOp = cast<vector::TransferWriteOp>(op);
98     assert(writeOp.getShapedType().isa<TensorType>() &&
99            "only tensor types expected");
100 
101     // Create a new transfer_write on buffer that doesn't have a return value.
102     FailureOr<Value> resultBuffer =
103         getBuffer(rewriter, writeOp.getSource(), options);
104     if (failed(resultBuffer))
105       return failure();
106     rewriter.create<vector::TransferWriteOp>(
107         writeOp.getLoc(), writeOp.getVector(), *resultBuffer,
108         writeOp.getIndices(), writeOp.getPermutationMapAttr(),
109         writeOp.getMask(), writeOp.getInBoundsAttr());
110     replaceOpWithBufferizedValues(rewriter, op, *resultBuffer);
111 
112     return success();
113   }
114 };
115 
116 } // namespace
117 } // namespace vector
118 } // namespace mlir
119 
120 void mlir::vector::registerBufferizableOpInterfaceExternalModels(
121     DialectRegistry &registry) {
122   registry.addExtension(+[](MLIRContext *ctx, vector::VectorDialect *dialect) {
123     TransferReadOp::attachInterface<TransferReadOpInterface>(*ctx);
124     TransferWriteOp::attachInterface<TransferWriteOpInterface>(*ctx);
125   });
126 }
127