1 //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"
10
11 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
12 #include "mlir/Dialect/Vector/IR/VectorOps.h"
13 #include "mlir/IR/Dialect.h"
14 #include "mlir/IR/Operation.h"
15
16 using namespace mlir;
17 using namespace mlir::bufferization;
18 using namespace mlir::vector;
19
20 namespace mlir {
21 namespace vector {
22 namespace {
23
24 /// Bufferization of vector.transfer_read. Replaced with a new
25 /// vector.transfer_read that operates on a memref.
26 struct TransferReadOpInterface
27 : public BufferizableOpInterface::ExternalModel<TransferReadOpInterface,
28 vector::TransferReadOp> {
bufferizesToMemoryReadmlir::vector::__anon48b3dc910111::TransferReadOpInterface29 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
30 const AnalysisState &state) const {
31 assert(opOperand.get().getType().isa<RankedTensorType>() &&
32 "only tensor types expected");
33 return true;
34 }
35
bufferizesToMemoryWritemlir::vector::__anon48b3dc910111::TransferReadOpInterface36 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
37 const AnalysisState &state) const {
38 assert(opOperand.get().getType().isa<RankedTensorType>() &&
39 "only tensor types expected");
40 return false;
41 }
42
getAliasingOpResultmlir::vector::__anon48b3dc910111::TransferReadOpInterface43 SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
44 const AnalysisState &state) const {
45 return {};
46 }
47
bufferizemlir::vector::__anon48b3dc910111::TransferReadOpInterface48 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
49 const BufferizationOptions &options) const {
50 auto readOp = cast<vector::TransferReadOp>(op);
51 assert(readOp.getShapedType().isa<TensorType>() &&
52 "only tensor types expected");
53 FailureOr<Value> buffer = getBuffer(rewriter, readOp.getSource(), options);
54 if (failed(buffer))
55 return failure();
56 replaceOpWithNewBufferizedOp<vector::TransferReadOp>(
57 rewriter, readOp, readOp.getVectorType(), *buffer, readOp.getIndices(),
58 readOp.getPermutationMap(), readOp.getPadding(), readOp.getMask(),
59 readOp.getInBoundsAttr());
60 return success();
61 }
62 };
63
64 /// Bufferization of vector.transfer_write. Replace with a new
65 /// vector.transfer_write that operates on a memref.
66 struct TransferWriteOpInterface
67 : public BufferizableOpInterface::ExternalModel<TransferWriteOpInterface,
68 vector::TransferWriteOp> {
bufferizesToMemoryReadmlir::vector::__anon48b3dc910111::TransferWriteOpInterface69 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
70 const AnalysisState &state) const {
71 assert(opOperand.get().getType().isa<TensorType>() &&
72 "only tensor types expected");
73 return true;
74 }
75
bufferizesToMemoryWritemlir::vector::__anon48b3dc910111::TransferWriteOpInterface76 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
77 const AnalysisState &state) const {
78 assert(opOperand.get().getType().isa<TensorType>() &&
79 "only tensor types expected");
80 return true;
81 }
82
getAliasingOpResultmlir::vector::__anon48b3dc910111::TransferWriteOpInterface83 SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
84 const AnalysisState &state) const {
85 assert(opOperand.get().getType().isa<TensorType>() &&
86 "only tensor types expected");
87 return {op->getOpResult(0)};
88 }
89
bufferRelationmlir::vector::__anon48b3dc910111::TransferWriteOpInterface90 BufferRelation bufferRelation(Operation *op, OpResult opResult,
91 const AnalysisState &state) const {
92 return BufferRelation::Equivalent;
93 }
94
bufferizemlir::vector::__anon48b3dc910111::TransferWriteOpInterface95 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
96 const BufferizationOptions &options) const {
97 auto writeOp = cast<vector::TransferWriteOp>(op);
98 assert(writeOp.getShapedType().isa<TensorType>() &&
99 "only tensor types expected");
100
101 // Create a new transfer_write on buffer that doesn't have a return value.
102 FailureOr<Value> resultBuffer =
103 getBuffer(rewriter, writeOp.getSource(), options);
104 if (failed(resultBuffer))
105 return failure();
106 rewriter.create<vector::TransferWriteOp>(
107 writeOp.getLoc(), writeOp.getVector(), *resultBuffer,
108 writeOp.getIndices(), writeOp.getPermutationMapAttr(),
109 writeOp.getMask(), writeOp.getInBoundsAttr());
110 replaceOpWithBufferizedValues(rewriter, op, *resultBuffer);
111
112 return success();
113 }
114 };
115
116 } // namespace
117 } // namespace vector
118 } // namespace mlir
119
registerBufferizableOpInterfaceExternalModels(DialectRegistry & registry)120 void mlir::vector::registerBufferizableOpInterfaceExternalModels(
121 DialectRegistry ®istry) {
122 registry.addExtension(+[](MLIRContext *ctx, vector::VectorDialect *dialect) {
123 TransferReadOp::attachInterface<TransferReadOpInterface>(*ctx);
124 TransferWriteOp::attachInterface<TransferWriteOpInterface>(*ctx);
125 });
126 }
127