1075e3fddSMatthias Springer //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
2075e3fddSMatthias Springer //
3075e3fddSMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4075e3fddSMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
5075e3fddSMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6075e3fddSMatthias Springer //
7075e3fddSMatthias Springer //===----------------------------------------------------------------------===//
8075e3fddSMatthias Springer 
9075e3fddSMatthias Springer #include "mlir/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.h"
10075e3fddSMatthias Springer #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
11075e3fddSMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
12075e3fddSMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h"
13075e3fddSMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h"
14075e3fddSMatthias Springer #include "mlir/IR/Dialect.h"
15075e3fddSMatthias Springer #include "mlir/IR/Operation.h"
16075e3fddSMatthias Springer 
17dec8af70SRiver Riddle using namespace mlir;
18075e3fddSMatthias Springer using namespace mlir::bufferization;
19075e3fddSMatthias Springer 
20075e3fddSMatthias Springer namespace {
21075e3fddSMatthias Springer /// Bufferization of arith.constant. Replace with memref.get_global.
22075e3fddSMatthias Springer struct ConstantOpInterface
23075e3fddSMatthias Springer     : public BufferizableOpInterface::ExternalModel<ConstantOpInterface,
24075e3fddSMatthias Springer                                                     arith::ConstantOp> {
25075e3fddSMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
269597b16aSMatthias Springer                           BufferizationState &state) const {
27075e3fddSMatthias Springer     auto constantOp = cast<arith::ConstantOp>(op);
28075e3fddSMatthias Springer 
29075e3fddSMatthias Springer     // Only ranked tensors are supported.
30075e3fddSMatthias Springer     if (!constantOp.getType().isa<RankedTensorType>())
31075e3fddSMatthias Springer       return failure();
32075e3fddSMatthias Springer 
33075e3fddSMatthias Springer     // Only constants inside a module are supported.
34075e3fddSMatthias Springer     auto moduleOp = constantOp->getParentOfType<ModuleOp>();
35075e3fddSMatthias Springer     if (!moduleOp)
36075e3fddSMatthias Springer       return failure();
37075e3fddSMatthias Springer 
38075e3fddSMatthias Springer     // Create global memory segment and replace tensor with memref pointing to
39075e3fddSMatthias Springer     // that memory segment.
40ab47418dSMatthias Springer     FailureOr<memref::GlobalOp> globalOp =
41ab47418dSMatthias Springer         getGlobalFor(constantOp, state.getOptions().bufferAlignment);
42ab47418dSMatthias Springer     if (failed(globalOp))
43ab47418dSMatthias Springer       return failure();
44ab47418dSMatthias Springer     memref::GlobalOp globalMemref = globalOp.getValue();
45075e3fddSMatthias Springer     replaceOpWithNewBufferizedOp<memref::GetGlobalOp>(
46075e3fddSMatthias Springer         rewriter, op, globalMemref.type(), globalMemref.getName());
47075e3fddSMatthias Springer 
48075e3fddSMatthias Springer     return success();
49075e3fddSMatthias Springer   }
50075e3fddSMatthias Springer 
51075e3fddSMatthias Springer   bool isWritable(Operation *op, Value value,
529597b16aSMatthias Springer                   const AnalysisState &state) const {
53075e3fddSMatthias Springer     // Memory locations returned by memref::GetGlobalOp may not be written to.
54075e3fddSMatthias Springer     assert(value.isa<OpResult>());
55075e3fddSMatthias Springer     return false;
56075e3fddSMatthias Springer   }
57075e3fddSMatthias Springer };
58075e3fddSMatthias Springer 
59075e3fddSMatthias Springer struct IndexCastOpInterface
60075e3fddSMatthias Springer     : public BufferizableOpInterface::ExternalModel<IndexCastOpInterface,
61075e3fddSMatthias Springer                                                     arith::IndexCastOp> {
62075e3fddSMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
639597b16aSMatthias Springer                               const AnalysisState &state) const {
64075e3fddSMatthias Springer     return false;
65075e3fddSMatthias Springer   }
66075e3fddSMatthias Springer 
67075e3fddSMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
689597b16aSMatthias Springer                                const AnalysisState &state) const {
69075e3fddSMatthias Springer     return false;
70075e3fddSMatthias Springer   }
71075e3fddSMatthias Springer 
729597b16aSMatthias Springer   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
739597b16aSMatthias Springer                                             const AnalysisState &state) const {
74585a8a32SMatthias Springer     return {op->getResult(0)};
75075e3fddSMatthias Springer   }
76075e3fddSMatthias Springer 
77075e3fddSMatthias Springer   BufferRelation bufferRelation(Operation *op, OpResult opResult,
789597b16aSMatthias Springer                                 const AnalysisState &state) const {
79075e3fddSMatthias Springer     return BufferRelation::Equivalent;
80075e3fddSMatthias Springer   }
81075e3fddSMatthias Springer 
82075e3fddSMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
839597b16aSMatthias Springer                           BufferizationState &state) const {
84075e3fddSMatthias Springer     auto castOp = cast<arith::IndexCastOp>(op);
85*12e41d92SMatthias Springer     auto resultTensorType = castOp.getType().cast<TensorType>();
86075e3fddSMatthias Springer 
87075e3fddSMatthias Springer     Value source = *state.getBuffer(rewriter, op->getOpOperand(0) /*in*/);
88075e3fddSMatthias Springer     auto sourceType = source.getType().cast<BaseMemRefType>();
89075e3fddSMatthias Springer 
90075e3fddSMatthias Springer     // Result type should have same layout and address space as the source type.
91*12e41d92SMatthias Springer     BaseMemRefType resultType;
92*12e41d92SMatthias Springer     if (auto rankedMemRefType = sourceType.dyn_cast<MemRefType>()) {
93*12e41d92SMatthias Springer       resultType = MemRefType::get(
94*12e41d92SMatthias Springer           rankedMemRefType.getShape(), resultTensorType.getElementType(),
95*12e41d92SMatthias Springer           rankedMemRefType.getLayout(), rankedMemRefType.getMemorySpace());
96*12e41d92SMatthias Springer     } else {
97*12e41d92SMatthias Springer       auto unrankedMemrefType = sourceType.cast<UnrankedMemRefType>();
98*12e41d92SMatthias Springer       resultType = UnrankedMemRefType::get(resultTensorType.getElementType(),
99*12e41d92SMatthias Springer                                            unrankedMemrefType.getMemorySpace());
100*12e41d92SMatthias Springer     }
101075e3fddSMatthias Springer 
1023c69bc4dSRiver Riddle     replaceOpWithNewBufferizedOp<arith::IndexCastOp>(rewriter, op, resultType,
1033c69bc4dSRiver Riddle                                                      source);
104075e3fddSMatthias Springer     return success();
105075e3fddSMatthias Springer   }
106075e3fddSMatthias Springer };
107075e3fddSMatthias Springer 
108dec8af70SRiver Riddle /// Bufferization of arith.select. Just replace the operands.
109dec8af70SRiver Riddle struct SelectOpInterface
110dec8af70SRiver Riddle     : public BufferizableOpInterface::ExternalModel<SelectOpInterface,
111dec8af70SRiver Riddle                                                     arith::SelectOp> {
112dec8af70SRiver Riddle   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
1139597b16aSMatthias Springer                               const AnalysisState &state) const {
114dec8af70SRiver Riddle     return false;
115dec8af70SRiver Riddle   }
116dec8af70SRiver Riddle 
117dec8af70SRiver Riddle   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
1189597b16aSMatthias Springer                                const AnalysisState &state) const {
119dec8af70SRiver Riddle     return false;
120dec8af70SRiver Riddle   }
121dec8af70SRiver Riddle 
1229597b16aSMatthias Springer   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
1239597b16aSMatthias Springer                                             const AnalysisState &state) const {
124585a8a32SMatthias Springer     return {op->getOpResult(0) /*result*/};
125dec8af70SRiver Riddle   }
126dec8af70SRiver Riddle 
127dec8af70SRiver Riddle   SmallVector<OpOperand *>
128dec8af70SRiver Riddle   getAliasingOpOperand(Operation *op, OpResult opResult,
1299597b16aSMatthias Springer                        const AnalysisState &state) const {
130dec8af70SRiver Riddle     return {&op->getOpOperand(1) /*true_value*/,
131dec8af70SRiver Riddle             &op->getOpOperand(2) /*false_value*/};
132dec8af70SRiver Riddle   }
133dec8af70SRiver Riddle 
134dec8af70SRiver Riddle   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
1359597b16aSMatthias Springer                           BufferizationState &state) const {
136dec8af70SRiver Riddle     auto selectOp = cast<arith::SelectOp>(op);
1378b091419SMatthias Springer     Location loc = selectOp.getLoc();
138dec8af70SRiver Riddle 
139dec8af70SRiver Riddle     // `getBuffer` introduces copies if an OpOperand bufferizes out-of-place.
140dec8af70SRiver Riddle     // TODO: It would be more efficient to copy the result of the `select` op
141dec8af70SRiver Riddle     // instead of its OpOperands. In the worst case, 2 copies are inserted at
142dec8af70SRiver Riddle     // the moment (one for each tensor). When copying the op result, only one
143dec8af70SRiver Riddle     // copy would be needed.
144dec8af70SRiver Riddle     Value trueBuffer =
145dec8af70SRiver Riddle         *state.getBuffer(rewriter, selectOp->getOpOperand(1) /*true_value*/);
146dec8af70SRiver Riddle     Value falseBuffer =
147dec8af70SRiver Riddle         *state.getBuffer(rewriter, selectOp->getOpOperand(2) /*false_value*/);
1488b091419SMatthias Springer 
1498b091419SMatthias Springer     // The "true" and the "false" operands must have the same type. If the
1508b091419SMatthias Springer     // buffers have different types, they differ only in their layout map. Cast
1518b091419SMatthias Springer     // both of them to the most dynamic MemRef type.
1528b091419SMatthias Springer     if (trueBuffer.getType() != falseBuffer.getType()) {
1538b091419SMatthias Springer       auto trueType = trueBuffer.getType().cast<MemRefType>();
1548b091419SMatthias Springer       int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset;
155*12e41d92SMatthias Springer       SmallVector<int64_t> dynamicStrides(trueType.getRank(),
1568b091419SMatthias Springer                                           ShapedType::kDynamicStrideOrOffset);
1578b091419SMatthias Springer       AffineMap stridedLayout = makeStridedLinearLayoutMap(
1588b091419SMatthias Springer           dynamicStrides, dynamicOffset, op->getContext());
159*12e41d92SMatthias Springer       auto castedType =
160*12e41d92SMatthias Springer           MemRefType::get(trueType.getShape(), trueType.getElementType(),
161*12e41d92SMatthias Springer                           stridedLayout, trueType.getMemorySpaceAsInt());
1628b091419SMatthias Springer       trueBuffer = rewriter.create<memref::CastOp>(loc, castedType, trueBuffer);
1638b091419SMatthias Springer       falseBuffer =
1648b091419SMatthias Springer           rewriter.create<memref::CastOp>(loc, castedType, falseBuffer);
1658b091419SMatthias Springer     }
1668b091419SMatthias Springer 
167dec8af70SRiver Riddle     replaceOpWithNewBufferizedOp<arith::SelectOp>(
168dec8af70SRiver Riddle         rewriter, op, selectOp.getCondition(), trueBuffer, falseBuffer);
169dec8af70SRiver Riddle     return success();
170dec8af70SRiver Riddle   }
171dec8af70SRiver Riddle 
172dec8af70SRiver Riddle   BufferRelation bufferRelation(Operation *op, OpResult opResult,
1739597b16aSMatthias Springer                                 const AnalysisState &state) const {
174dec8af70SRiver Riddle     return BufferRelation::None;
175dec8af70SRiver Riddle   }
176dec8af70SRiver Riddle };
177dec8af70SRiver Riddle 
178075e3fddSMatthias Springer } // namespace
179075e3fddSMatthias Springer 
180075e3fddSMatthias Springer void mlir::arith::registerBufferizableOpInterfaceExternalModels(
181075e3fddSMatthias Springer     DialectRegistry &registry) {
18277eee579SRiver Riddle   registry.addExtension(+[](MLIRContext *ctx, ArithmeticDialect *dialect) {
18377eee579SRiver Riddle     ConstantOp::attachInterface<ConstantOpInterface>(*ctx);
18477eee579SRiver Riddle     IndexCastOp::attachInterface<IndexCastOpInterface>(*ctx);
18577eee579SRiver Riddle     SelectOp::attachInterface<SelectOpInterface>(*ctx);
18677eee579SRiver Riddle   });
187075e3fddSMatthias Springer }
188