1 //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.h"
10 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
11 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
12 #include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h"
13 #include "mlir/Dialect/MemRef/IR/MemRef.h"
14 #include "mlir/IR/Dialect.h"
15 #include "mlir/IR/Operation.h"
16 
17 using namespace mlir::bufferization;
18 
19 namespace mlir {
20 namespace arith {
21 namespace {
22 
23 /// Bufferization of arith.constant. Replace with memref.get_global.
24 struct ConstantOpInterface
25     : public BufferizableOpInterface::ExternalModel<ConstantOpInterface,
26                                                     arith::ConstantOp> {
27   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
28                           const BufferizationState &state) const {
29     auto constantOp = cast<arith::ConstantOp>(op);
30 
31     // Only ranked tensors are supported.
32     if (!constantOp.getType().isa<RankedTensorType>())
33       return failure();
34 
35     // Only constants inside a module are supported.
36     auto moduleOp = constantOp->getParentOfType<ModuleOp>();
37     if (!moduleOp)
38       return failure();
39 
40     // Create global memory segment and replace tensor with memref pointing to
41     // that memory segment.
42     FailureOr<memref::GlobalOp> globalOp =
43         getGlobalFor(constantOp, state.getOptions().bufferAlignment);
44     if (failed(globalOp))
45       return failure();
46     memref::GlobalOp globalMemref = globalOp.getValue();
47     replaceOpWithNewBufferizedOp<memref::GetGlobalOp>(
48         rewriter, op, globalMemref.type(), globalMemref.getName());
49 
50     return success();
51   }
52 
53   bool isWritable(Operation *op, Value value,
54                   const BufferizationState &state) const {
55     // Memory locations returned by memref::GetGlobalOp may not be written to.
56     assert(value.isa<OpResult>());
57     return false;
58   }
59 };
60 
61 struct IndexCastOpInterface
62     : public BufferizableOpInterface::ExternalModel<IndexCastOpInterface,
63                                                     arith::IndexCastOp> {
64   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
65                               const BufferizationState &state) const {
66     return false;
67   }
68 
69   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
70                                const BufferizationState &state) const {
71     return false;
72   }
73 
74   OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
75                                const BufferizationState &state) const {
76     return op->getResult(0);
77   }
78 
79   BufferRelation bufferRelation(Operation *op, OpResult opResult,
80                                 const BufferizationState &state) const {
81     return BufferRelation::Equivalent;
82   }
83 
84   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
85                           const BufferizationState &state) const {
86     auto castOp = cast<arith::IndexCastOp>(op);
87 
88     Value source = *state.getBuffer(rewriter, op->getOpOperand(0) /*in*/);
89     auto sourceType = source.getType().cast<BaseMemRefType>();
90 
91     // Result type should have same layout and address space as the source type.
92     MemRefLayoutAttrInterface layout = {};
93     if (auto rankedMemRefType = sourceType.dyn_cast<MemRefType>())
94       layout = rankedMemRefType.getLayout();
95     Type resultType =
96         getMemRefType(castOp.getType().cast<TensorType>(), state.getOptions(),
97                       layout, sourceType.getMemorySpace());
98 
99     replaceOpWithNewBufferizedOp<arith::IndexCastOp>(rewriter, op, source,
100                                                      resultType);
101     return success();
102   }
103 };
104 
105 } // namespace
106 } // namespace arith
107 } // namespace mlir
108 
109 void mlir::arith::registerBufferizableOpInterfaceExternalModels(
110     DialectRegistry &registry) {
111   registry.addOpInterface<ConstantOp, ConstantOpInterface>();
112   registry.addOpInterface<IndexCastOp, IndexCastOpInterface>();
113 }
114