1 //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.h"
10 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
11 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
12 #include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h"
13 #include "mlir/Dialect/MemRef/IR/MemRef.h"
14 #include "mlir/IR/Dialect.h"
15 #include "mlir/IR/Operation.h"
16
17 using namespace mlir;
18 using namespace mlir::bufferization;
19
20 namespace {
21 /// Bufferization of arith.constant. Replace with memref.get_global.
22 struct ConstantOpInterface
23 : public BufferizableOpInterface::ExternalModel<ConstantOpInterface,
24 arith::ConstantOp> {
bufferize__anonbafe9d880111::ConstantOpInterface25 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
26 const BufferizationOptions &options) const {
27 auto constantOp = cast<arith::ConstantOp>(op);
28
29 // TODO: Implement memory space for this op. E.g., by adding a memory_space
30 // attribute to ConstantOp.
31 if (options.defaultMemorySpace != static_cast<unsigned>(0))
32 return op->emitError("memory space not implemented yet");
33
34 // Only ranked tensors are supported.
35 if (!constantOp.getType().isa<RankedTensorType>())
36 return failure();
37
38 // Only constants inside a module are supported.
39 auto moduleOp = constantOp->getParentOfType<ModuleOp>();
40 if (!moduleOp)
41 return failure();
42
43 // Create global memory segment and replace tensor with memref pointing to
44 // that memory segment.
45 FailureOr<memref::GlobalOp> globalOp =
46 getGlobalFor(constantOp, options.bufferAlignment);
47 if (failed(globalOp))
48 return failure();
49 memref::GlobalOp globalMemref = *globalOp;
50 replaceOpWithNewBufferizedOp<memref::GetGlobalOp>(
51 rewriter, op, globalMemref.getType(), globalMemref.getName());
52
53 return success();
54 }
55
isWritable__anonbafe9d880111::ConstantOpInterface56 bool isWritable(Operation *op, Value value,
57 const AnalysisState &state) const {
58 // Memory locations returned by memref::GetGlobalOp may not be written to.
59 assert(value.isa<OpResult>());
60 return false;
61 }
62 };
63
64 struct IndexCastOpInterface
65 : public BufferizableOpInterface::ExternalModel<IndexCastOpInterface,
66 arith::IndexCastOp> {
bufferizesToMemoryRead__anonbafe9d880111::IndexCastOpInterface67 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
68 const AnalysisState &state) const {
69 return false;
70 }
71
bufferizesToMemoryWrite__anonbafe9d880111::IndexCastOpInterface72 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
73 const AnalysisState &state) const {
74 return false;
75 }
76
getAliasingOpResult__anonbafe9d880111::IndexCastOpInterface77 SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
78 const AnalysisState &state) const {
79 return {op->getResult(0)};
80 }
81
bufferRelation__anonbafe9d880111::IndexCastOpInterface82 BufferRelation bufferRelation(Operation *op, OpResult opResult,
83 const AnalysisState &state) const {
84 return BufferRelation::Equivalent;
85 }
86
bufferize__anonbafe9d880111::IndexCastOpInterface87 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
88 const BufferizationOptions &options) const {
89 auto castOp = cast<arith::IndexCastOp>(op);
90 auto resultTensorType = castOp.getType().cast<TensorType>();
91
92 FailureOr<Value> source = getBuffer(rewriter, castOp.getIn(), options);
93 if (failed(source))
94 return failure();
95 auto sourceType = source->getType().cast<BaseMemRefType>();
96
97 // Result type should have same layout and address space as the source type.
98 BaseMemRefType resultType;
99 if (auto rankedMemRefType = sourceType.dyn_cast<MemRefType>()) {
100 resultType = MemRefType::get(
101 rankedMemRefType.getShape(), resultTensorType.getElementType(),
102 rankedMemRefType.getLayout(), rankedMemRefType.getMemorySpace());
103 } else {
104 auto unrankedMemrefType = sourceType.cast<UnrankedMemRefType>();
105 resultType = UnrankedMemRefType::get(resultTensorType.getElementType(),
106 unrankedMemrefType.getMemorySpace());
107 }
108
109 replaceOpWithNewBufferizedOp<arith::IndexCastOp>(rewriter, op, resultType,
110 *source);
111 return success();
112 }
113 };
114
115 /// Bufferization of arith.select. Just replace the operands.
116 struct SelectOpInterface
117 : public BufferizableOpInterface::ExternalModel<SelectOpInterface,
118 arith::SelectOp> {
bufferizesToMemoryRead__anonbafe9d880111::SelectOpInterface119 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
120 const AnalysisState &state) const {
121 return false;
122 }
123
bufferizesToMemoryWrite__anonbafe9d880111::SelectOpInterface124 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
125 const AnalysisState &state) const {
126 return false;
127 }
128
getAliasingOpResult__anonbafe9d880111::SelectOpInterface129 SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
130 const AnalysisState &state) const {
131 return {op->getOpResult(0) /*result*/};
132 }
133
134 SmallVector<OpOperand *>
getAliasingOpOperand__anonbafe9d880111::SelectOpInterface135 getAliasingOpOperand(Operation *op, OpResult opResult,
136 const AnalysisState &state) const {
137 return {&op->getOpOperand(1) /*true_value*/,
138 &op->getOpOperand(2) /*false_value*/};
139 }
140
bufferize__anonbafe9d880111::SelectOpInterface141 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
142 const BufferizationOptions &options) const {
143 auto selectOp = cast<arith::SelectOp>(op);
144 Location loc = selectOp.getLoc();
145
146 // TODO: It would be more efficient to copy the result of the `select` op
147 // instead of its OpOperands. In the worst case, 2 copies are inserted at
148 // the moment (one for each tensor). When copying the op result, only one
149 // copy would be needed.
150 FailureOr<Value> maybeTrueBuffer =
151 getBuffer(rewriter, selectOp.getTrueValue(), options);
152 FailureOr<Value> maybeFalseBuffer =
153 getBuffer(rewriter, selectOp.getFalseValue(), options);
154 if (failed(maybeTrueBuffer) || failed(maybeFalseBuffer))
155 return failure();
156 Value trueBuffer = *maybeTrueBuffer;
157 Value falseBuffer = *maybeFalseBuffer;
158 BaseMemRefType trueType = trueBuffer.getType().cast<BaseMemRefType>();
159 BaseMemRefType falseType = falseBuffer.getType().cast<BaseMemRefType>();
160 if (trueType.getMemorySpaceAsInt() != falseType.getMemorySpaceAsInt())
161 return op->emitError("inconsistent memory space on true/false operands");
162
163 // The "true" and the "false" operands must have the same type. If the
164 // buffers have different types, they differ only in their layout map. Cast
165 // both of them to the most dynamic MemRef type.
166 if (trueBuffer.getType() != falseBuffer.getType()) {
167 auto trueType = trueBuffer.getType().cast<MemRefType>();
168 int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset;
169 SmallVector<int64_t> dynamicStrides(trueType.getRank(),
170 ShapedType::kDynamicStrideOrOffset);
171 AffineMap stridedLayout = makeStridedLinearLayoutMap(
172 dynamicStrides, dynamicOffset, op->getContext());
173 auto castedType =
174 MemRefType::get(trueType.getShape(), trueType.getElementType(),
175 stridedLayout, trueType.getMemorySpaceAsInt());
176 trueBuffer = rewriter.create<memref::CastOp>(loc, castedType, trueBuffer);
177 falseBuffer =
178 rewriter.create<memref::CastOp>(loc, castedType, falseBuffer);
179 }
180
181 replaceOpWithNewBufferizedOp<arith::SelectOp>(
182 rewriter, op, selectOp.getCondition(), trueBuffer, falseBuffer);
183 return success();
184 }
185
bufferRelation__anonbafe9d880111::SelectOpInterface186 BufferRelation bufferRelation(Operation *op, OpResult opResult,
187 const AnalysisState &state) const {
188 return BufferRelation::None;
189 }
190 };
191
192 } // namespace
193
registerBufferizableOpInterfaceExternalModels(DialectRegistry & registry)194 void mlir::arith::registerBufferizableOpInterfaceExternalModels(
195 DialectRegistry ®istry) {
196 registry.addExtension(+[](MLIRContext *ctx, ArithmeticDialect *dialect) {
197 ConstantOp::attachInterface<ConstantOpInterface>(*ctx);
198 IndexCastOp::attachInterface<IndexCastOpInterface>(*ctx);
199 SelectOp::attachInterface<SelectOpInterface>(*ctx);
200 });
201 }
202