1*c66303c2SMatthias Springer //===- DenseBufferizationPass.cpp - Dense bufferization pass --------------===//
2*c66303c2SMatthias Springer //
3*c66303c2SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*c66303c2SMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
5*c66303c2SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*c66303c2SMatthias Springer //
7*c66303c2SMatthias Springer //===----------------------------------------------------------------------===//
8*c66303c2SMatthias Springer 
9*c66303c2SMatthias Springer #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
10*c66303c2SMatthias Springer 
11*c66303c2SMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
12*c66303c2SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
13*c66303c2SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
14*c66303c2SMatthias Springer #include "mlir/Dialect/Func/IR/FuncOps.h"
15*c66303c2SMatthias Springer #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
16*c66303c2SMatthias Springer 
17*c66303c2SMatthias Springer using namespace mlir;
18*c66303c2SMatthias Springer using namespace mlir::func;
19*c66303c2SMatthias Springer 
20*c66303c2SMatthias Springer namespace mlir {
21*c66303c2SMatthias Springer namespace sparse_tensor {
22*c66303c2SMatthias Springer 
23*c66303c2SMatthias Springer /// Return `true` if one of the given types is a sparse tensor type.
containsSparseTensor(TypeRange types)24*c66303c2SMatthias Springer static bool containsSparseTensor(TypeRange types) {
25*c66303c2SMatthias Springer   for (Type t : types)
26*c66303c2SMatthias Springer     if (getSparseTensorEncoding(t))
27*c66303c2SMatthias Springer       return true;
28*c66303c2SMatthias Springer   return false;
29*c66303c2SMatthias Springer }
30*c66303c2SMatthias Springer 
31*c66303c2SMatthias Springer /// A pass that bufferizes only dense tensor ops and ignores all sparse tensor
32*c66303c2SMatthias Springer /// ops. No buffer copies are inserted. All tensor OpOperands must be
33*c66303c2SMatthias Springer /// inplacable.
34*c66303c2SMatthias Springer class BufferizeDenseOpsPass
35*c66303c2SMatthias Springer     : public PassWrapper<BufferizeDenseOpsPass, OperationPass<ModuleOp>> {
36*c66303c2SMatthias Springer public:
BufferizeDenseOpsPass(const bufferization::OneShotBufferizationOptions & options)37*c66303c2SMatthias Springer   BufferizeDenseOpsPass(
38*c66303c2SMatthias Springer       const bufferization::OneShotBufferizationOptions &options)
39*c66303c2SMatthias Springer       : PassWrapper<BufferizeDenseOpsPass, OperationPass<ModuleOp>>(),
40*c66303c2SMatthias Springer         options(options) {}
41*c66303c2SMatthias Springer 
runOnOperation()42*c66303c2SMatthias Springer   void runOnOperation() override {
43*c66303c2SMatthias Springer     // Disallow all sparse tensor ops, so that only dense tensor ops are
44*c66303c2SMatthias Springer     // bufferized.
45*c66303c2SMatthias Springer     bufferization::OpFilter opFilter;
46*c66303c2SMatthias Springer     opFilter.allowOperation([&](Operation *op) {
47*c66303c2SMatthias Springer       if (containsSparseTensor(TypeRange(op->getResults())) ||
48*c66303c2SMatthias Springer           containsSparseTensor(TypeRange(op->getOperands())))
49*c66303c2SMatthias Springer         return false;
50*c66303c2SMatthias Springer       if (auto funcOp = dyn_cast<func::FuncOp>(op)) {
51*c66303c2SMatthias Springer         FunctionType funcType = funcOp.getFunctionType();
52*c66303c2SMatthias Springer         if (containsSparseTensor(funcType.getInputs()) ||
53*c66303c2SMatthias Springer             containsSparseTensor(funcType.getResults()))
54*c66303c2SMatthias Springer           return false;
55*c66303c2SMatthias Springer       }
56*c66303c2SMatthias Springer       return true;
57*c66303c2SMatthias Springer     });
58*c66303c2SMatthias Springer 
59*c66303c2SMatthias Springer     if (failed(bufferization::bufferizeOp(getOperation(), options,
60*c66303c2SMatthias Springer                                           /*copyBeforeWrite=*/false,
61*c66303c2SMatthias Springer                                           &opFilter)))
62*c66303c2SMatthias Springer       signalPassFailure();
63*c66303c2SMatthias Springer   }
64*c66303c2SMatthias Springer 
65*c66303c2SMatthias Springer private:
66*c66303c2SMatthias Springer   bufferization::OneShotBufferizationOptions options;
67*c66303c2SMatthias Springer };
68*c66303c2SMatthias Springer } // namespace sparse_tensor
69*c66303c2SMatthias Springer } // namespace mlir
70*c66303c2SMatthias Springer 
createDenseBufferizationPass(const bufferization::OneShotBufferizationOptions & options)71*c66303c2SMatthias Springer std::unique_ptr<Pass> mlir::createDenseBufferizationPass(
72*c66303c2SMatthias Springer     const bufferization::OneShotBufferizationOptions &options) {
73*c66303c2SMatthias Springer   return std::make_unique<mlir::sparse_tensor::BufferizeDenseOpsPass>(options);
74*c66303c2SMatthias Springer }
75