1 //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"
10 
11 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
12 #include "mlir/Dialect/Vector/IR/VectorOps.h"
13 #include "mlir/IR/Dialect.h"
14 #include "mlir/IR/Operation.h"
15 
16 using namespace mlir;
17 using namespace mlir::bufferization;
18 using namespace mlir::vector;
19 
20 namespace mlir {
21 namespace vector {
22 namespace {
23 
24 /// Bufferization of vector.transfer_read. Replaced with a new
25 /// vector.transfer_read that operates on a memref.
26 struct TransferReadOpInterface
27     : public BufferizableOpInterface::ExternalModel<TransferReadOpInterface,
28                                                     vector::TransferReadOp> {
29   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
30                               const BufferizationState &state) const {
31     assert(opOperand.get().getType().isa<RankedTensorType>() &&
32            "only tensor types expected");
33     return true;
34   }
35 
36   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
37                                const BufferizationState &state) const {
38     assert(opOperand.get().getType().isa<RankedTensorType>() &&
39            "only tensor types expected");
40     return false;
41   }
42 
43   SmallVector<OpResult>
44   getAliasingOpResult(Operation *op, OpOperand &opOperand,
45                       const BufferizationState &state) const {
46     return {};
47   }
48 
49   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
50                           const BufferizationState &state) const {
51     auto readOp = cast<vector::TransferReadOp>(op);
52     assert(readOp.getShapedType().isa<TensorType>() &&
53            "only tensor types expected");
54 
55     // TransferReadOp always reads from the bufferized op.source().
56     Value buffer =
57         *state.getBuffer(rewriter, readOp->getOpOperand(0) /*source*/);
58     replaceOpWithNewBufferizedOp<vector::TransferReadOp>(
59         rewriter, readOp, readOp.getVectorType(), buffer, readOp.indices(),
60         readOp.permutation_map(), readOp.padding(), readOp.mask(),
61         readOp.in_boundsAttr());
62     return success();
63   }
64 };
65 
66 /// Bufferization of vector.transfer_write. Replace with a new
67 /// vector.transfer_write that operates on a memref.
68 struct TransferWriteOpInterface
69     : public BufferizableOpInterface::ExternalModel<TransferWriteOpInterface,
70                                                     vector::TransferWriteOp> {
71   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
72                               const BufferizationState &state) const {
73     assert(opOperand.get().getType().isa<TensorType>() &&
74            "only tensor types expected");
75     return true;
76   }
77 
78   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
79                                const BufferizationState &state) const {
80     assert(opOperand.get().getType().isa<TensorType>() &&
81            "only tensor types expected");
82     return true;
83   }
84 
85   SmallVector<OpResult>
86   getAliasingOpResult(Operation *op, OpOperand &opOperand,
87                       const BufferizationState &state) const {
88     assert(opOperand.get().getType().isa<TensorType>() &&
89            "only tensor types expected");
90     return {op->getOpResult(0)};
91   }
92 
93   BufferRelation bufferRelation(Operation *op, OpResult opResult,
94                                 const BufferizationState &state) const {
95     return BufferRelation::Equivalent;
96   }
97 
98   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
99                           const BufferizationState &state) const {
100     auto writeOp = cast<vector::TransferWriteOp>(op);
101     assert(writeOp.getShapedType().isa<TensorType>() &&
102            "only tensor types expected");
103 
104     // Create a new transfer_write on buffer that doesn't have a return value.
105     // Leave the previous transfer_write to dead code as it still has uses at
106     // this point.
107     FailureOr<Value> resultBuffer =
108         state.getBuffer(rewriter, op->getOpOperand(1) /*source*/);
109     if (failed(resultBuffer))
110       return failure();
111     rewriter.create<vector::TransferWriteOp>(
112         writeOp.getLoc(), writeOp.vector(), *resultBuffer, writeOp.indices(),
113         writeOp.permutation_mapAttr(), writeOp.in_boundsAttr());
114     replaceOpWithBufferizedValues(rewriter, op, *resultBuffer);
115 
116     return success();
117   }
118 };
119 
120 } // namespace
121 } // namespace vector
122 } // namespace mlir
123 
124 void mlir::vector::registerBufferizableOpInterfaceExternalModels(
125     DialectRegistry &registry) {
126   registry.addOpInterface<TransferReadOp, TransferReadOpInterface>();
127   registry.addOpInterface<TransferWriteOp, TransferWriteOpInterface>();
128 }
129