1 //===- DenseBufferizationPass.cpp - Dense bufferization pass --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
10 
11 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
12 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
13 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
14 #include "mlir/Dialect/Func/IR/FuncOps.h"
15 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
16 
17 using namespace mlir;
18 using namespace mlir::func;
19 
20 namespace mlir {
21 namespace sparse_tensor {
22 
23 /// Return `true` if one of the given types is a sparse tensor type.
24 static bool containsSparseTensor(TypeRange types) {
25   for (Type t : types)
26     if (getSparseTensorEncoding(t))
27       return true;
28   return false;
29 }
30 
31 /// A pass that bufferizes only dense tensor ops and ignores all sparse tensor
32 /// ops. No buffer copies are inserted. All tensor OpOperands must be
33 /// inplacable.
34 class BufferizeDenseOpsPass
35     : public PassWrapper<BufferizeDenseOpsPass, OperationPass<ModuleOp>> {
36 public:
37   BufferizeDenseOpsPass(
38       const bufferization::OneShotBufferizationOptions &options)
39       : PassWrapper<BufferizeDenseOpsPass, OperationPass<ModuleOp>>(),
40         options(options) {}
41 
42   void runOnOperation() override {
43     // Disallow all sparse tensor ops, so that only dense tensor ops are
44     // bufferized.
45     bufferization::OpFilter opFilter;
46     opFilter.allowOperation([&](Operation *op) {
47       if (containsSparseTensor(TypeRange(op->getResults())) ||
48           containsSparseTensor(TypeRange(op->getOperands())))
49         return false;
50       if (auto funcOp = dyn_cast<func::FuncOp>(op)) {
51         FunctionType funcType = funcOp.getFunctionType();
52         if (containsSparseTensor(funcType.getInputs()) ||
53             containsSparseTensor(funcType.getResults()))
54           return false;
55       }
56       return true;
57     });
58 
59     if (failed(bufferization::bufferizeOp(getOperation(), options,
60                                           /*copyBeforeWrite=*/false,
61                                           &opFilter)))
62       signalPassFailure();
63   }
64 
65 private:
66   bufferization::OneShotBufferizationOptions options;
67 };
68 } // namespace sparse_tensor
69 } // namespace mlir
70 
71 std::unique_ptr<Pass> mlir::createDenseBufferizationPass(
72     const bufferization::OneShotBufferizationOptions &options) {
73   return std::make_unique<mlir::sparse_tensor::BufferizeDenseOpsPass>(options);
74 }
75