1 //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.h"
10 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
11 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
12 #include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h"
13 #include "mlir/Dialect/MemRef/IR/MemRef.h"
14 #include "mlir/IR/Dialect.h"
15 #include "mlir/IR/Operation.h"
16 
17 using namespace mlir;
18 using namespace mlir::bufferization;
19 
20 namespace {
21 /// Bufferization of arith.constant. Replace with memref.get_global.
22 struct ConstantOpInterface
23     : public BufferizableOpInterface::ExternalModel<ConstantOpInterface,
24                                                     arith::ConstantOp> {
25   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
26                           const BufferizationState &state) const {
27     auto constantOp = cast<arith::ConstantOp>(op);
28 
29     // Only ranked tensors are supported.
30     if (!constantOp.getType().isa<RankedTensorType>())
31       return failure();
32 
33     // Only constants inside a module are supported.
34     auto moduleOp = constantOp->getParentOfType<ModuleOp>();
35     if (!moduleOp)
36       return failure();
37 
38     // Create global memory segment and replace tensor with memref pointing to
39     // that memory segment.
40     FailureOr<memref::GlobalOp> globalOp =
41         getGlobalFor(constantOp, state.getOptions().bufferAlignment);
42     if (failed(globalOp))
43       return failure();
44     memref::GlobalOp globalMemref = globalOp.getValue();
45     replaceOpWithNewBufferizedOp<memref::GetGlobalOp>(
46         rewriter, op, globalMemref.type(), globalMemref.getName());
47 
48     return success();
49   }
50 
51   bool isWritable(Operation *op, Value value,
52                   const BufferizationState &state) const {
53     // Memory locations returned by memref::GetGlobalOp may not be written to.
54     assert(value.isa<OpResult>());
55     return false;
56   }
57 };
58 
59 struct IndexCastOpInterface
60     : public BufferizableOpInterface::ExternalModel<IndexCastOpInterface,
61                                                     arith::IndexCastOp> {
62   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
63                               const BufferizationState &state) const {
64     return false;
65   }
66 
67   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
68                                const BufferizationState &state) const {
69     return false;
70   }
71 
72   SmallVector<OpResult>
73   getAliasingOpResult(Operation *op, OpOperand &opOperand,
74                       const BufferizationState &state) const {
75     return {op->getResult(0)};
76   }
77 
78   BufferRelation bufferRelation(Operation *op, OpResult opResult,
79                                 const BufferizationState &state) const {
80     return BufferRelation::Equivalent;
81   }
82 
83   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
84                           const BufferizationState &state) const {
85     auto castOp = cast<arith::IndexCastOp>(op);
86 
87     Value source = *state.getBuffer(rewriter, op->getOpOperand(0) /*in*/);
88     auto sourceType = source.getType().cast<BaseMemRefType>();
89 
90     // Result type should have same layout and address space as the source type.
91     MemRefLayoutAttrInterface layout = {};
92     if (auto rankedMemRefType = sourceType.dyn_cast<MemRefType>())
93       layout = rankedMemRefType.getLayout();
94     Type resultType =
95         getMemRefType(castOp.getType().cast<TensorType>(), state.getOptions(),
96                       layout, sourceType.getMemorySpace());
97 
98     replaceOpWithNewBufferizedOp<arith::IndexCastOp>(rewriter, op, resultType,
99                                                      source);
100     return success();
101   }
102 };
103 
104 /// Bufferization of arith.select. Just replace the operands.
105 struct SelectOpInterface
106     : public BufferizableOpInterface::ExternalModel<SelectOpInterface,
107                                                     arith::SelectOp> {
108   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
109                               const BufferizationState &state) const {
110     return false;
111   }
112 
113   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
114                                const BufferizationState &state) const {
115     return false;
116   }
117 
118   SmallVector<OpResult>
119   getAliasingOpResult(Operation *op, OpOperand &opOperand,
120                       const BufferizationState &state) const {
121     return {op->getOpResult(0) /*result*/};
122   }
123 
124   SmallVector<OpOperand *>
125   getAliasingOpOperand(Operation *op, OpResult opResult,
126                        const BufferizationState &state) const {
127     return {&op->getOpOperand(1) /*true_value*/,
128             &op->getOpOperand(2) /*false_value*/};
129   }
130 
131   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
132                           const BufferizationState &state) const {
133     auto selectOp = cast<arith::SelectOp>(op);
134 
135     // `getBuffer` introduces copies if an OpOperand bufferizes out-of-place.
136     // TODO: It would be more efficient to copy the result of the `select` op
137     // instead of its OpOperands. In the worst case, 2 copies are inserted at
138     // the moment (one for each tensor). When copying the op result, only one
139     // copy would be needed.
140     Value trueBuffer =
141         *state.getBuffer(rewriter, selectOp->getOpOperand(1) /*true_value*/);
142     Value falseBuffer =
143         *state.getBuffer(rewriter, selectOp->getOpOperand(2) /*false_value*/);
144     replaceOpWithNewBufferizedOp<arith::SelectOp>(
145         rewriter, op, selectOp.getCondition(), trueBuffer, falseBuffer);
146     return success();
147   }
148 
149   BufferRelation bufferRelation(Operation *op, OpResult opResult,
150                                 const BufferizationState &state) const {
151     return BufferRelation::None;
152   }
153 };
154 
155 } // namespace
156 
157 void mlir::arith::registerBufferizableOpInterfaceExternalModels(
158     DialectRegistry &registry) {
159   registry.addOpInterface<ConstantOp, ConstantOpInterface>();
160   registry.addOpInterface<IndexCastOp, IndexCastOpInterface>();
161   registry.addOpInterface<SelectOp, SelectOpInterface>();
162 }
163