1 //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// These BufferizableOpInterface implementations provide analysis-related
// interface methods only. The ops themselves are bufferized by the
// SparseTensorConversion pass.
12 
13 #include "mlir/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.h"
14 
15 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
16 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
17 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
18 #include "mlir/IR/Dialect.h"
19 #include "mlir/IR/Operation.h"
20 #include "mlir/IR/PatternMatch.h"
21 
22 using namespace mlir::bufferization;
23 using namespace mlir::sparse_tensor;
24 
25 namespace mlir {
26 namespace sparse_tensor {
27 namespace {
28 
29 struct ConvertOpInterface
30     : public BufferizableOpInterface::ExternalModel<ConvertOpInterface,
31                                                     sparse_tensor::ConvertOp> {
32   bool bufferizesToAllocation(Operation *op, OpResult opResult) const {
33     // ConvertOps may allocate. (Unless they convert between two identical
34     // types, then they fold away.)
35     return true;
36   }
37 
38   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
39                               const AnalysisState &state) const {
40     return true;
41   }
42 
43   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
44                                const AnalysisState &state) const {
45     return false;
46   }
47 
48   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
49                                             const AnalysisState &state) const {
50     return {};
51   }
52 
53   bool isWritable(Operation *op, Value value,
54                   const AnalysisState &state) const {
55     return true;
56   }
57 };
58 
59 struct LoadOpInterface
60     : public BufferizableOpInterface::ExternalModel<LoadOpInterface,
61                                                     sparse_tensor::LoadOp> {
62   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
63                               const AnalysisState &state) const {
64     return false;
65   }
66 
67   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
68                                const AnalysisState &state) const {
69     return false;
70   }
71 
72   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
73                                             const AnalysisState &state) const {
74     return {op->getOpResult(0)};
75   }
76 
77   BufferRelation bufferRelation(Operation *op, OpResult opResult,
78                                 const AnalysisState &state) const {
79     return BufferRelation::Equivalent;
80   }
81 };
82 
83 struct NewOpInterface
84     : public BufferizableOpInterface::ExternalModel<NewOpInterface,
85                                                     sparse_tensor::NewOp> {
86   bool isMemoryWrite(Operation *op, OpResult opResult,
87                      const AnalysisState &state) const {
88     // NewOps allocate but do not write.
89     return false;
90   }
91 
92   bool bufferizesToAllocation(Operation *op, OpResult opResult) const {
93     return true;
94   }
95 };
96 
97 } // namespace
98 } // namespace sparse_tensor
99 } // namespace mlir
100 
101 void mlir::sparse_tensor::registerBufferizableOpInterfaceExternalModels(
102     DialectRegistry &registry) {
103   registry.addExtension(
104       +[](MLIRContext *ctx, sparse_tensor::SparseTensorDialect *dialect) {
105         sparse_tensor::ConvertOp::attachInterface<ConvertOpInterface>(*ctx);
106         sparse_tensor::LoadOp::attachInterface<LoadOpInterface>(*ctx);
107         sparse_tensor::NewOp::attachInterface<NewOpInterface>(*ctx);
108       });
109 }
110