1075e3fddSMatthias Springer //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
2075e3fddSMatthias Springer //
3075e3fddSMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4075e3fddSMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
5075e3fddSMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6075e3fddSMatthias Springer //
7075e3fddSMatthias Springer //===----------------------------------------------------------------------===//
8075e3fddSMatthias Springer 
9075e3fddSMatthias Springer #include "mlir/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.h"
10075e3fddSMatthias Springer #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
11075e3fddSMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
12075e3fddSMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h"
13075e3fddSMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h"
14075e3fddSMatthias Springer #include "mlir/IR/Dialect.h"
15075e3fddSMatthias Springer #include "mlir/IR/Operation.h"
16075e3fddSMatthias Springer 
17dec8af70SRiver Riddle using namespace mlir;
18075e3fddSMatthias Springer using namespace mlir::bufferization;
19075e3fddSMatthias Springer 
20075e3fddSMatthias Springer namespace {
21075e3fddSMatthias Springer /// Bufferization of arith.constant. Replace with memref.get_global.
22075e3fddSMatthias Springer struct ConstantOpInterface
23075e3fddSMatthias Springer     : public BufferizableOpInterface::ExternalModel<ConstantOpInterface,
24075e3fddSMatthias Springer                                                     arith::ConstantOp> {
bufferize__anonbafe9d880111::ConstantOpInterface25075e3fddSMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
26b55d55ecSMatthias Springer                           const BufferizationOptions &options) const {
27075e3fddSMatthias Springer     auto constantOp = cast<arith::ConstantOp>(op);
28075e3fddSMatthias Springer 
29c0b0b6a0SMatthias Springer     // TODO: Implement memory space for this op. E.g., by adding a memory_space
30c0b0b6a0SMatthias Springer     // attribute to ConstantOp.
31c0b0b6a0SMatthias Springer     if (options.defaultMemorySpace != static_cast<unsigned>(0))
32c0b0b6a0SMatthias Springer       return op->emitError("memory space not implemented yet");
33c0b0b6a0SMatthias Springer 
34075e3fddSMatthias Springer     // Only ranked tensors are supported.
35075e3fddSMatthias Springer     if (!constantOp.getType().isa<RankedTensorType>())
36075e3fddSMatthias Springer       return failure();
37075e3fddSMatthias Springer 
38075e3fddSMatthias Springer     // Only constants inside a module are supported.
39075e3fddSMatthias Springer     auto moduleOp = constantOp->getParentOfType<ModuleOp>();
40075e3fddSMatthias Springer     if (!moduleOp)
41075e3fddSMatthias Springer       return failure();
42075e3fddSMatthias Springer 
43075e3fddSMatthias Springer     // Create global memory segment and replace tensor with memref pointing to
44075e3fddSMatthias Springer     // that memory segment.
45ab47418dSMatthias Springer     FailureOr<memref::GlobalOp> globalOp =
46b55d55ecSMatthias Springer         getGlobalFor(constantOp, options.bufferAlignment);
47ab47418dSMatthias Springer     if (failed(globalOp))
48ab47418dSMatthias Springer       return failure();
496d5fc1e3SKazu Hirata     memref::GlobalOp globalMemref = *globalOp;
50075e3fddSMatthias Springer     replaceOpWithNewBufferizedOp<memref::GetGlobalOp>(
51*136d746eSJacques Pienaar         rewriter, op, globalMemref.getType(), globalMemref.getName());
52075e3fddSMatthias Springer 
53075e3fddSMatthias Springer     return success();
54075e3fddSMatthias Springer   }
55075e3fddSMatthias Springer 
isWritable__anonbafe9d880111::ConstantOpInterface56075e3fddSMatthias Springer   bool isWritable(Operation *op, Value value,
579597b16aSMatthias Springer                   const AnalysisState &state) const {
58075e3fddSMatthias Springer     // Memory locations returned by memref::GetGlobalOp may not be written to.
59075e3fddSMatthias Springer     assert(value.isa<OpResult>());
60075e3fddSMatthias Springer     return false;
61075e3fddSMatthias Springer   }
62075e3fddSMatthias Springer };
63075e3fddSMatthias Springer 
64075e3fddSMatthias Springer struct IndexCastOpInterface
65075e3fddSMatthias Springer     : public BufferizableOpInterface::ExternalModel<IndexCastOpInterface,
66075e3fddSMatthias Springer                                                     arith::IndexCastOp> {
bufferizesToMemoryRead__anonbafe9d880111::IndexCastOpInterface67075e3fddSMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
689597b16aSMatthias Springer                               const AnalysisState &state) const {
69075e3fddSMatthias Springer     return false;
70075e3fddSMatthias Springer   }
71075e3fddSMatthias Springer 
bufferizesToMemoryWrite__anonbafe9d880111::IndexCastOpInterface72075e3fddSMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
739597b16aSMatthias Springer                                const AnalysisState &state) const {
74075e3fddSMatthias Springer     return false;
75075e3fddSMatthias Springer   }
76075e3fddSMatthias Springer 
getAliasingOpResult__anonbafe9d880111::IndexCastOpInterface779597b16aSMatthias Springer   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
789597b16aSMatthias Springer                                             const AnalysisState &state) const {
79585a8a32SMatthias Springer     return {op->getResult(0)};
80075e3fddSMatthias Springer   }
81075e3fddSMatthias Springer 
bufferRelation__anonbafe9d880111::IndexCastOpInterface82075e3fddSMatthias Springer   BufferRelation bufferRelation(Operation *op, OpResult opResult,
839597b16aSMatthias Springer                                 const AnalysisState &state) const {
84075e3fddSMatthias Springer     return BufferRelation::Equivalent;
85075e3fddSMatthias Springer   }
86075e3fddSMatthias Springer 
bufferize__anonbafe9d880111::IndexCastOpInterface87075e3fddSMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
88b55d55ecSMatthias Springer                           const BufferizationOptions &options) const {
89075e3fddSMatthias Springer     auto castOp = cast<arith::IndexCastOp>(op);
9012e41d92SMatthias Springer     auto resultTensorType = castOp.getType().cast<TensorType>();
91075e3fddSMatthias Springer 
925d50f51cSMatthias Springer     FailureOr<Value> source = getBuffer(rewriter, castOp.getIn(), options);
935d50f51cSMatthias Springer     if (failed(source))
945d50f51cSMatthias Springer       return failure();
955d50f51cSMatthias Springer     auto sourceType = source->getType().cast<BaseMemRefType>();
96075e3fddSMatthias Springer 
97075e3fddSMatthias Springer     // Result type should have same layout and address space as the source type.
9812e41d92SMatthias Springer     BaseMemRefType resultType;
9912e41d92SMatthias Springer     if (auto rankedMemRefType = sourceType.dyn_cast<MemRefType>()) {
10012e41d92SMatthias Springer       resultType = MemRefType::get(
10112e41d92SMatthias Springer           rankedMemRefType.getShape(), resultTensorType.getElementType(),
10212e41d92SMatthias Springer           rankedMemRefType.getLayout(), rankedMemRefType.getMemorySpace());
10312e41d92SMatthias Springer     } else {
10412e41d92SMatthias Springer       auto unrankedMemrefType = sourceType.cast<UnrankedMemRefType>();
10512e41d92SMatthias Springer       resultType = UnrankedMemRefType::get(resultTensorType.getElementType(),
10612e41d92SMatthias Springer                                            unrankedMemrefType.getMemorySpace());
10712e41d92SMatthias Springer     }
108075e3fddSMatthias Springer 
1093c69bc4dSRiver Riddle     replaceOpWithNewBufferizedOp<arith::IndexCastOp>(rewriter, op, resultType,
1105d50f51cSMatthias Springer                                                      *source);
111075e3fddSMatthias Springer     return success();
112075e3fddSMatthias Springer   }
113075e3fddSMatthias Springer };
114075e3fddSMatthias Springer 
115dec8af70SRiver Riddle /// Bufferization of arith.select. Just replace the operands.
116dec8af70SRiver Riddle struct SelectOpInterface
117dec8af70SRiver Riddle     : public BufferizableOpInterface::ExternalModel<SelectOpInterface,
118dec8af70SRiver Riddle                                                     arith::SelectOp> {
bufferizesToMemoryRead__anonbafe9d880111::SelectOpInterface119dec8af70SRiver Riddle   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
1209597b16aSMatthias Springer                               const AnalysisState &state) const {
121dec8af70SRiver Riddle     return false;
122dec8af70SRiver Riddle   }
123dec8af70SRiver Riddle 
bufferizesToMemoryWrite__anonbafe9d880111::SelectOpInterface124dec8af70SRiver Riddle   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
1259597b16aSMatthias Springer                                const AnalysisState &state) const {
126dec8af70SRiver Riddle     return false;
127dec8af70SRiver Riddle   }
128dec8af70SRiver Riddle 
getAliasingOpResult__anonbafe9d880111::SelectOpInterface1299597b16aSMatthias Springer   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
1309597b16aSMatthias Springer                                             const AnalysisState &state) const {
131585a8a32SMatthias Springer     return {op->getOpResult(0) /*result*/};
132dec8af70SRiver Riddle   }
133dec8af70SRiver Riddle 
134dec8af70SRiver Riddle   SmallVector<OpOperand *>
getAliasingOpOperand__anonbafe9d880111::SelectOpInterface135dec8af70SRiver Riddle   getAliasingOpOperand(Operation *op, OpResult opResult,
1369597b16aSMatthias Springer                        const AnalysisState &state) const {
137dec8af70SRiver Riddle     return {&op->getOpOperand(1) /*true_value*/,
138dec8af70SRiver Riddle             &op->getOpOperand(2) /*false_value*/};
139dec8af70SRiver Riddle   }
140dec8af70SRiver Riddle 
bufferize__anonbafe9d880111::SelectOpInterface141dec8af70SRiver Riddle   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
142b55d55ecSMatthias Springer                           const BufferizationOptions &options) const {
143dec8af70SRiver Riddle     auto selectOp = cast<arith::SelectOp>(op);
1448b091419SMatthias Springer     Location loc = selectOp.getLoc();
145dec8af70SRiver Riddle 
146dec8af70SRiver Riddle     // TODO: It would be more efficient to copy the result of the `select` op
147dec8af70SRiver Riddle     // instead of its OpOperands. In the worst case, 2 copies are inserted at
148dec8af70SRiver Riddle     // the moment (one for each tensor). When copying the op result, only one
149dec8af70SRiver Riddle     // copy would be needed.
1505d50f51cSMatthias Springer     FailureOr<Value> maybeTrueBuffer =
1515d50f51cSMatthias Springer         getBuffer(rewriter, selectOp.getTrueValue(), options);
1525d50f51cSMatthias Springer     FailureOr<Value> maybeFalseBuffer =
1535d50f51cSMatthias Springer         getBuffer(rewriter, selectOp.getFalseValue(), options);
1545d50f51cSMatthias Springer     if (failed(maybeTrueBuffer) || failed(maybeFalseBuffer))
1555d50f51cSMatthias Springer       return failure();
1565d50f51cSMatthias Springer     Value trueBuffer = *maybeTrueBuffer;
1575d50f51cSMatthias Springer     Value falseBuffer = *maybeFalseBuffer;
158c0b0b6a0SMatthias Springer     BaseMemRefType trueType = trueBuffer.getType().cast<BaseMemRefType>();
159c0b0b6a0SMatthias Springer     BaseMemRefType falseType = falseBuffer.getType().cast<BaseMemRefType>();
160c0b0b6a0SMatthias Springer     if (trueType.getMemorySpaceAsInt() != falseType.getMemorySpaceAsInt())
161c0b0b6a0SMatthias Springer       return op->emitError("inconsistent memory space on true/false operands");
1628b091419SMatthias Springer 
1638b091419SMatthias Springer     // The "true" and the "false" operands must have the same type. If the
1648b091419SMatthias Springer     // buffers have different types, they differ only in their layout map. Cast
1658b091419SMatthias Springer     // both of them to the most dynamic MemRef type.
1668b091419SMatthias Springer     if (trueBuffer.getType() != falseBuffer.getType()) {
1678b091419SMatthias Springer       auto trueType = trueBuffer.getType().cast<MemRefType>();
1688b091419SMatthias Springer       int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset;
16912e41d92SMatthias Springer       SmallVector<int64_t> dynamicStrides(trueType.getRank(),
1708b091419SMatthias Springer                                           ShapedType::kDynamicStrideOrOffset);
1718b091419SMatthias Springer       AffineMap stridedLayout = makeStridedLinearLayoutMap(
1728b091419SMatthias Springer           dynamicStrides, dynamicOffset, op->getContext());
17312e41d92SMatthias Springer       auto castedType =
17412e41d92SMatthias Springer           MemRefType::get(trueType.getShape(), trueType.getElementType(),
17512e41d92SMatthias Springer                           stridedLayout, trueType.getMemorySpaceAsInt());
1768b091419SMatthias Springer       trueBuffer = rewriter.create<memref::CastOp>(loc, castedType, trueBuffer);
1778b091419SMatthias Springer       falseBuffer =
1788b091419SMatthias Springer           rewriter.create<memref::CastOp>(loc, castedType, falseBuffer);
1798b091419SMatthias Springer     }
1808b091419SMatthias Springer 
181dec8af70SRiver Riddle     replaceOpWithNewBufferizedOp<arith::SelectOp>(
182dec8af70SRiver Riddle         rewriter, op, selectOp.getCondition(), trueBuffer, falseBuffer);
183dec8af70SRiver Riddle     return success();
184dec8af70SRiver Riddle   }
185dec8af70SRiver Riddle 
bufferRelation__anonbafe9d880111::SelectOpInterface186dec8af70SRiver Riddle   BufferRelation bufferRelation(Operation *op, OpResult opResult,
1879597b16aSMatthias Springer                                 const AnalysisState &state) const {
188dec8af70SRiver Riddle     return BufferRelation::None;
189dec8af70SRiver Riddle   }
190dec8af70SRiver Riddle };
191dec8af70SRiver Riddle 
192075e3fddSMatthias Springer } // namespace
193075e3fddSMatthias Springer 
registerBufferizableOpInterfaceExternalModels(DialectRegistry & registry)194075e3fddSMatthias Springer void mlir::arith::registerBufferizableOpInterfaceExternalModels(
195075e3fddSMatthias Springer     DialectRegistry &registry) {
19677eee579SRiver Riddle   registry.addExtension(+[](MLIRContext *ctx, ArithmeticDialect *dialect) {
19777eee579SRiver Riddle     ConstantOp::attachInterface<ConstantOpInterface>(*ctx);
19877eee579SRiver Riddle     IndexCastOp::attachInterface<IndexCastOpInterface>(*ctx);
19977eee579SRiver Riddle     SelectOp::attachInterface<SelectOpInterface>(*ctx);
20077eee579SRiver Riddle   });
201075e3fddSMatthias Springer }
202