//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace mlir::bufferization;

namespace {
/// Bufferization of arith.constant. Replace with memref.get_global.
struct ConstantOpInterface
    : public BufferizableOpInterface::ExternalModel<ConstantOpInterface,
                                                    arith::ConstantOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto constantOp = cast<arith::ConstantOp>(op);

    // TODO: Implement memory space for this op. E.g., by adding a memory_space
    // attribute to ConstantOp.
    if (options.defaultMemorySpace != static_cast<unsigned>(0))
      return op->emitError("memory space not implemented yet");

    // Only ranked tensors are supported.
    if (!constantOp.getType().isa<RankedTensorType>())
      return failure();

    // Only constants inside a module are supported.
    auto moduleOp = constantOp->getParentOfType<ModuleOp>();
    if (!moduleOp)
      return failure();

    // Create global memory segment and replace tensor with memref pointing to
    // that memory segment.
    FailureOr<memref::GlobalOp> globalOp =
        getGlobalFor(constantOp, options.bufferAlignment);
    if (failed(globalOp))
      return failure();
    memref::GlobalOp globalMemref = *globalOp;
    replaceOpWithNewBufferizedOp<memref::GetGlobalOp>(
        rewriter, op, globalMemref.getType(), globalMemref.getName());

    return success();
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    // Memory locations returned by memref::GetGlobalOp may not be written to.
    assert(value.isa<OpResult>());
    return false;
  }
};

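/// Bufferization of arith.index_cast. Replace with an arith.index_cast that
/// operates on memrefs.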
struct IndexCastOpInterface
    : public BufferizableOpInterface::ExternalModel<IndexCastOpInterface,
                                                    arith::IndexCastOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {op->getResult(0)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto castOp = cast<arith::IndexCastOp>(op);
    auto resultTensorType = castOp.getType().cast<TensorType>();

    FailureOr<Value> source = getBuffer(rewriter, castOp.getIn(), options);
    if (failed(source))
      return failure();
    auto sourceType = source->getType().cast<BaseMemRefType>();

    // Result type should have same layout and address space as the source
    // type.
    BaseMemRefType resultType;
    if (auto rankedMemRefType = sourceType.dyn_cast<MemRefType>()) {
      resultType = MemRefType::get(
          rankedMemRefType.getShape(), resultTensorType.getElementType(),
          rankedMemRefType.getLayout(), rankedMemRefType.getMemorySpace());
    } else {
      auto unrankedMemrefType = sourceType.cast<UnrankedMemRefType>();
      resultType = UnrankedMemRefType::get(resultTensorType.getElementType(),
                                           unrankedMemrefType.getMemorySpace());
    }

    replaceOpWithNewBufferizedOp<arith::IndexCastOp>(rewriter, op, resultType,
                                                     *source);
    return success();
  }
};

/// Bufferization of arith.select. Just replace the operands.
struct SelectOpInterface
    : public BufferizableOpInterface::ExternalModel<SelectOpInterface,
                                                    arith::SelectOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {op->getOpResult(0) /*result*/};
  }

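  // The result may alias with either the true_value or the false_value
  // operand, depending on the condition.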
  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    return {&op->getOpOperand(1) /*true_value*/,
            &op->getOpOperand(2) /*false_value*/};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto selectOp = cast<arith::SelectOp>(op);
    Location loc = selectOp.getLoc();

    // TODO: It would be more efficient to copy the result of the `select` op
    // instead of its OpOperands. In the worst case, 2 copies are inserted at
    // the moment (one for each tensor). When copying the op result, only one
    // copy would be needed.
    FailureOr<Value> maybeTrueBuffer =
        getBuffer(rewriter, selectOp.getTrueValue(), options);
    FailureOr<Value> maybeFalseBuffer =
        getBuffer(rewriter, selectOp.getFalseValue(), options);
    if (failed(maybeTrueBuffer) || failed(maybeFalseBuffer))
      return failure();
    Value trueBuffer = *maybeTrueBuffer;
    Value falseBuffer = *maybeFalseBuffer;
    BaseMemRefType trueType = trueBuffer.getType().cast<BaseMemRefType>();
    BaseMemRefType falseType = falseBuffer.getType().cast<BaseMemRefType>();
    if (trueType.getMemorySpaceAsInt() != falseType.getMemorySpaceAsInt())
      return op->emitError("inconsistent memory space on true/false operands");

    // The "true" and the "false" operands must have the same type. If the
    // buffers have different types, they differ only in their layout map. Cast
    // both of them to the most dynamic MemRef type.
    if (trueBuffer.getType() != falseBuffer.getType()) {
      auto trueType = trueBuffer.getType().cast<MemRefType>();
      int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset;
      SmallVector<int64_t> dynamicStrides(trueType.getRank(),
                                          ShapedType::kDynamicStrideOrOffset);
      AffineMap stridedLayout = makeStridedLinearLayoutMap(
          dynamicStrides, dynamicOffset, op->getContext());
      auto castedType =
          MemRefType::get(trueType.getShape(), trueType.getElementType(),
                          stridedLayout, trueType.getMemorySpaceAsInt());
      trueBuffer = rewriter.create<memref::CastOp>(loc, castedType, trueBuffer);
      falseBuffer =
          rewriter.create<memref::CastOp>(loc, castedType, falseBuffer);
    }

    replaceOpWithNewBufferizedOp<arith::SelectOp>(
        rewriter, op, selectOp.getCondition(), trueBuffer, falseBuffer);
    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::None;
  }
};

} // namespace

void mlir::arith::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, ArithmeticDialect *dialect) {
    ConstantOp::attachInterface<ConstantOpInterface>(*ctx);
    IndexCastOp::attachInterface<IndexCastOpInterface>(*ctx);
    SelectOp::attachInterface<SelectOpInterface>(*ctx);
  });
}