1075e3fddSMatthias Springer //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
2075e3fddSMatthias Springer //
3075e3fddSMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4075e3fddSMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
5075e3fddSMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6075e3fddSMatthias Springer //
7075e3fddSMatthias Springer //===----------------------------------------------------------------------===//
8075e3fddSMatthias Springer
9075e3fddSMatthias Springer #include "mlir/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.h"
10075e3fddSMatthias Springer #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
11075e3fddSMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
12075e3fddSMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h"
13075e3fddSMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h"
14075e3fddSMatthias Springer #include "mlir/IR/Dialect.h"
15075e3fddSMatthias Springer #include "mlir/IR/Operation.h"
16075e3fddSMatthias Springer
17dec8af70SRiver Riddle using namespace mlir;
18075e3fddSMatthias Springer using namespace mlir::bufferization;
19075e3fddSMatthias Springer
20075e3fddSMatthias Springer namespace {
21075e3fddSMatthias Springer /// Bufferization of arith.constant. Replace with memref.get_global.
22075e3fddSMatthias Springer struct ConstantOpInterface
23075e3fddSMatthias Springer : public BufferizableOpInterface::ExternalModel<ConstantOpInterface,
24075e3fddSMatthias Springer arith::ConstantOp> {
bufferize__anonbafe9d880111::ConstantOpInterface25075e3fddSMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
26b55d55ecSMatthias Springer const BufferizationOptions &options) const {
27075e3fddSMatthias Springer auto constantOp = cast<arith::ConstantOp>(op);
28075e3fddSMatthias Springer
29c0b0b6a0SMatthias Springer // TODO: Implement memory space for this op. E.g., by adding a memory_space
30c0b0b6a0SMatthias Springer // attribute to ConstantOp.
31c0b0b6a0SMatthias Springer if (options.defaultMemorySpace != static_cast<unsigned>(0))
32c0b0b6a0SMatthias Springer return op->emitError("memory space not implemented yet");
33c0b0b6a0SMatthias Springer
34075e3fddSMatthias Springer // Only ranked tensors are supported.
35075e3fddSMatthias Springer if (!constantOp.getType().isa<RankedTensorType>())
36075e3fddSMatthias Springer return failure();
37075e3fddSMatthias Springer
38075e3fddSMatthias Springer // Only constants inside a module are supported.
39075e3fddSMatthias Springer auto moduleOp = constantOp->getParentOfType<ModuleOp>();
40075e3fddSMatthias Springer if (!moduleOp)
41075e3fddSMatthias Springer return failure();
42075e3fddSMatthias Springer
43075e3fddSMatthias Springer // Create global memory segment and replace tensor with memref pointing to
44075e3fddSMatthias Springer // that memory segment.
45ab47418dSMatthias Springer FailureOr<memref::GlobalOp> globalOp =
46b55d55ecSMatthias Springer getGlobalFor(constantOp, options.bufferAlignment);
47ab47418dSMatthias Springer if (failed(globalOp))
48ab47418dSMatthias Springer return failure();
496d5fc1e3SKazu Hirata memref::GlobalOp globalMemref = *globalOp;
50075e3fddSMatthias Springer replaceOpWithNewBufferizedOp<memref::GetGlobalOp>(
51*136d746eSJacques Pienaar rewriter, op, globalMemref.getType(), globalMemref.getName());
52075e3fddSMatthias Springer
53075e3fddSMatthias Springer return success();
54075e3fddSMatthias Springer }
55075e3fddSMatthias Springer
isWritable__anonbafe9d880111::ConstantOpInterface56075e3fddSMatthias Springer bool isWritable(Operation *op, Value value,
579597b16aSMatthias Springer const AnalysisState &state) const {
58075e3fddSMatthias Springer // Memory locations returned by memref::GetGlobalOp may not be written to.
59075e3fddSMatthias Springer assert(value.isa<OpResult>());
60075e3fddSMatthias Springer return false;
61075e3fddSMatthias Springer }
62075e3fddSMatthias Springer };
63075e3fddSMatthias Springer
64075e3fddSMatthias Springer struct IndexCastOpInterface
65075e3fddSMatthias Springer : public BufferizableOpInterface::ExternalModel<IndexCastOpInterface,
66075e3fddSMatthias Springer arith::IndexCastOp> {
bufferizesToMemoryRead__anonbafe9d880111::IndexCastOpInterface67075e3fddSMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
689597b16aSMatthias Springer const AnalysisState &state) const {
69075e3fddSMatthias Springer return false;
70075e3fddSMatthias Springer }
71075e3fddSMatthias Springer
bufferizesToMemoryWrite__anonbafe9d880111::IndexCastOpInterface72075e3fddSMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
739597b16aSMatthias Springer const AnalysisState &state) const {
74075e3fddSMatthias Springer return false;
75075e3fddSMatthias Springer }
76075e3fddSMatthias Springer
getAliasingOpResult__anonbafe9d880111::IndexCastOpInterface779597b16aSMatthias Springer SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
789597b16aSMatthias Springer const AnalysisState &state) const {
79585a8a32SMatthias Springer return {op->getResult(0)};
80075e3fddSMatthias Springer }
81075e3fddSMatthias Springer
bufferRelation__anonbafe9d880111::IndexCastOpInterface82075e3fddSMatthias Springer BufferRelation bufferRelation(Operation *op, OpResult opResult,
839597b16aSMatthias Springer const AnalysisState &state) const {
84075e3fddSMatthias Springer return BufferRelation::Equivalent;
85075e3fddSMatthias Springer }
86075e3fddSMatthias Springer
bufferize__anonbafe9d880111::IndexCastOpInterface87075e3fddSMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
88b55d55ecSMatthias Springer const BufferizationOptions &options) const {
89075e3fddSMatthias Springer auto castOp = cast<arith::IndexCastOp>(op);
9012e41d92SMatthias Springer auto resultTensorType = castOp.getType().cast<TensorType>();
91075e3fddSMatthias Springer
925d50f51cSMatthias Springer FailureOr<Value> source = getBuffer(rewriter, castOp.getIn(), options);
935d50f51cSMatthias Springer if (failed(source))
945d50f51cSMatthias Springer return failure();
955d50f51cSMatthias Springer auto sourceType = source->getType().cast<BaseMemRefType>();
96075e3fddSMatthias Springer
97075e3fddSMatthias Springer // Result type should have same layout and address space as the source type.
9812e41d92SMatthias Springer BaseMemRefType resultType;
9912e41d92SMatthias Springer if (auto rankedMemRefType = sourceType.dyn_cast<MemRefType>()) {
10012e41d92SMatthias Springer resultType = MemRefType::get(
10112e41d92SMatthias Springer rankedMemRefType.getShape(), resultTensorType.getElementType(),
10212e41d92SMatthias Springer rankedMemRefType.getLayout(), rankedMemRefType.getMemorySpace());
10312e41d92SMatthias Springer } else {
10412e41d92SMatthias Springer auto unrankedMemrefType = sourceType.cast<UnrankedMemRefType>();
10512e41d92SMatthias Springer resultType = UnrankedMemRefType::get(resultTensorType.getElementType(),
10612e41d92SMatthias Springer unrankedMemrefType.getMemorySpace());
10712e41d92SMatthias Springer }
108075e3fddSMatthias Springer
1093c69bc4dSRiver Riddle replaceOpWithNewBufferizedOp<arith::IndexCastOp>(rewriter, op, resultType,
1105d50f51cSMatthias Springer *source);
111075e3fddSMatthias Springer return success();
112075e3fddSMatthias Springer }
113075e3fddSMatthias Springer };
114075e3fddSMatthias Springer
115dec8af70SRiver Riddle /// Bufferization of arith.select. Just replace the operands.
116dec8af70SRiver Riddle struct SelectOpInterface
117dec8af70SRiver Riddle : public BufferizableOpInterface::ExternalModel<SelectOpInterface,
118dec8af70SRiver Riddle arith::SelectOp> {
bufferizesToMemoryRead__anonbafe9d880111::SelectOpInterface119dec8af70SRiver Riddle bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
1209597b16aSMatthias Springer const AnalysisState &state) const {
121dec8af70SRiver Riddle return false;
122dec8af70SRiver Riddle }
123dec8af70SRiver Riddle
bufferizesToMemoryWrite__anonbafe9d880111::SelectOpInterface124dec8af70SRiver Riddle bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
1259597b16aSMatthias Springer const AnalysisState &state) const {
126dec8af70SRiver Riddle return false;
127dec8af70SRiver Riddle }
128dec8af70SRiver Riddle
getAliasingOpResult__anonbafe9d880111::SelectOpInterface1299597b16aSMatthias Springer SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
1309597b16aSMatthias Springer const AnalysisState &state) const {
131585a8a32SMatthias Springer return {op->getOpResult(0) /*result*/};
132dec8af70SRiver Riddle }
133dec8af70SRiver Riddle
134dec8af70SRiver Riddle SmallVector<OpOperand *>
getAliasingOpOperand__anonbafe9d880111::SelectOpInterface135dec8af70SRiver Riddle getAliasingOpOperand(Operation *op, OpResult opResult,
1369597b16aSMatthias Springer const AnalysisState &state) const {
137dec8af70SRiver Riddle return {&op->getOpOperand(1) /*true_value*/,
138dec8af70SRiver Riddle &op->getOpOperand(2) /*false_value*/};
139dec8af70SRiver Riddle }
140dec8af70SRiver Riddle
bufferize__anonbafe9d880111::SelectOpInterface141dec8af70SRiver Riddle LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
142b55d55ecSMatthias Springer const BufferizationOptions &options) const {
143dec8af70SRiver Riddle auto selectOp = cast<arith::SelectOp>(op);
1448b091419SMatthias Springer Location loc = selectOp.getLoc();
145dec8af70SRiver Riddle
146dec8af70SRiver Riddle // TODO: It would be more efficient to copy the result of the `select` op
147dec8af70SRiver Riddle // instead of its OpOperands. In the worst case, 2 copies are inserted at
148dec8af70SRiver Riddle // the moment (one for each tensor). When copying the op result, only one
149dec8af70SRiver Riddle // copy would be needed.
1505d50f51cSMatthias Springer FailureOr<Value> maybeTrueBuffer =
1515d50f51cSMatthias Springer getBuffer(rewriter, selectOp.getTrueValue(), options);
1525d50f51cSMatthias Springer FailureOr<Value> maybeFalseBuffer =
1535d50f51cSMatthias Springer getBuffer(rewriter, selectOp.getFalseValue(), options);
1545d50f51cSMatthias Springer if (failed(maybeTrueBuffer) || failed(maybeFalseBuffer))
1555d50f51cSMatthias Springer return failure();
1565d50f51cSMatthias Springer Value trueBuffer = *maybeTrueBuffer;
1575d50f51cSMatthias Springer Value falseBuffer = *maybeFalseBuffer;
158c0b0b6a0SMatthias Springer BaseMemRefType trueType = trueBuffer.getType().cast<BaseMemRefType>();
159c0b0b6a0SMatthias Springer BaseMemRefType falseType = falseBuffer.getType().cast<BaseMemRefType>();
160c0b0b6a0SMatthias Springer if (trueType.getMemorySpaceAsInt() != falseType.getMemorySpaceAsInt())
161c0b0b6a0SMatthias Springer return op->emitError("inconsistent memory space on true/false operands");
1628b091419SMatthias Springer
1638b091419SMatthias Springer // The "true" and the "false" operands must have the same type. If the
1648b091419SMatthias Springer // buffers have different types, they differ only in their layout map. Cast
1658b091419SMatthias Springer // both of them to the most dynamic MemRef type.
1668b091419SMatthias Springer if (trueBuffer.getType() != falseBuffer.getType()) {
1678b091419SMatthias Springer auto trueType = trueBuffer.getType().cast<MemRefType>();
1688b091419SMatthias Springer int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset;
16912e41d92SMatthias Springer SmallVector<int64_t> dynamicStrides(trueType.getRank(),
1708b091419SMatthias Springer ShapedType::kDynamicStrideOrOffset);
1718b091419SMatthias Springer AffineMap stridedLayout = makeStridedLinearLayoutMap(
1728b091419SMatthias Springer dynamicStrides, dynamicOffset, op->getContext());
17312e41d92SMatthias Springer auto castedType =
17412e41d92SMatthias Springer MemRefType::get(trueType.getShape(), trueType.getElementType(),
17512e41d92SMatthias Springer stridedLayout, trueType.getMemorySpaceAsInt());
1768b091419SMatthias Springer trueBuffer = rewriter.create<memref::CastOp>(loc, castedType, trueBuffer);
1778b091419SMatthias Springer falseBuffer =
1788b091419SMatthias Springer rewriter.create<memref::CastOp>(loc, castedType, falseBuffer);
1798b091419SMatthias Springer }
1808b091419SMatthias Springer
181dec8af70SRiver Riddle replaceOpWithNewBufferizedOp<arith::SelectOp>(
182dec8af70SRiver Riddle rewriter, op, selectOp.getCondition(), trueBuffer, falseBuffer);
183dec8af70SRiver Riddle return success();
184dec8af70SRiver Riddle }
185dec8af70SRiver Riddle
bufferRelation__anonbafe9d880111::SelectOpInterface186dec8af70SRiver Riddle BufferRelation bufferRelation(Operation *op, OpResult opResult,
1879597b16aSMatthias Springer const AnalysisState &state) const {
188dec8af70SRiver Riddle return BufferRelation::None;
189dec8af70SRiver Riddle }
190dec8af70SRiver Riddle };
191dec8af70SRiver Riddle
192075e3fddSMatthias Springer } // namespace
193075e3fddSMatthias Springer
registerBufferizableOpInterfaceExternalModels(DialectRegistry & registry)194075e3fddSMatthias Springer void mlir::arith::registerBufferizableOpInterfaceExternalModels(
195075e3fddSMatthias Springer DialectRegistry ®istry) {
19677eee579SRiver Riddle registry.addExtension(+[](MLIRContext *ctx, ArithmeticDialect *dialect) {
19777eee579SRiver Riddle ConstantOp::attachInterface<ConstantOpInterface>(*ctx);
19877eee579SRiver Riddle IndexCastOp::attachInterface<IndexCastOpInterface>(*ctx);
19977eee579SRiver Riddle SelectOp::attachInterface<SelectOpInterface>(*ctx);
20077eee579SRiver Riddle });
201075e3fddSMatthias Springer }
202