1*c66303c2SMatthias Springer //===- DenseBufferizationPass.cpp - Dense bufferization pass --------------===// 2*c66303c2SMatthias Springer // 3*c66303c2SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4*c66303c2SMatthias Springer // See https://llvm.org/LICENSE.txt for license information. 5*c66303c2SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6*c66303c2SMatthias Springer // 7*c66303c2SMatthias Springer //===----------------------------------------------------------------------===// 8*c66303c2SMatthias Springer 9*c66303c2SMatthias Springer #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" 10*c66303c2SMatthias Springer 11*c66303c2SMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 12*c66303c2SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" 13*c66303c2SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" 14*c66303c2SMatthias Springer #include "mlir/Dialect/Func/IR/FuncOps.h" 15*c66303c2SMatthias Springer #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 16*c66303c2SMatthias Springer 17*c66303c2SMatthias Springer using namespace mlir; 18*c66303c2SMatthias Springer using namespace mlir::func; 19*c66303c2SMatthias Springer 20*c66303c2SMatthias Springer namespace mlir { 21*c66303c2SMatthias Springer namespace sparse_tensor { 22*c66303c2SMatthias Springer 23*c66303c2SMatthias Springer /// Return `true` if one of the given types is a sparse tensor type. containsSparseTensor(TypeRange types)24*c66303c2SMatthias Springerstatic bool containsSparseTensor(TypeRange types) { 25*c66303c2SMatthias Springer for (Type t : types) 26*c66303c2SMatthias Springer if (getSparseTensorEncoding(t)) 27*c66303c2SMatthias Springer return true; 28*c66303c2SMatthias Springer return false; 29*c66303c2SMatthias Springer } 30*c66303c2SMatthias Springer 31*c66303c2SMatthias Springer /// A pass that bufferizes only dense tensor ops and ignores all sparse tensor 32*c66303c2SMatthias Springer /// ops. No buffer copies are inserted. All tensor OpOperands must be 33*c66303c2SMatthias Springer /// inplacable. 34*c66303c2SMatthias Springer class BufferizeDenseOpsPass 35*c66303c2SMatthias Springer : public PassWrapper<BufferizeDenseOpsPass, OperationPass<ModuleOp>> { 36*c66303c2SMatthias Springer public: BufferizeDenseOpsPass(const bufferization::OneShotBufferizationOptions & options)37*c66303c2SMatthias Springer BufferizeDenseOpsPass( 38*c66303c2SMatthias Springer const bufferization::OneShotBufferizationOptions &options) 39*c66303c2SMatthias Springer : PassWrapper<BufferizeDenseOpsPass, OperationPass<ModuleOp>>(), 40*c66303c2SMatthias Springer options(options) {} 41*c66303c2SMatthias Springer runOnOperation()42*c66303c2SMatthias Springer void runOnOperation() override { 43*c66303c2SMatthias Springer // Disallow all sparse tensor ops, so that only dense tensor ops are 44*c66303c2SMatthias Springer // bufferized. 45*c66303c2SMatthias Springer bufferization::OpFilter opFilter; 46*c66303c2SMatthias Springer opFilter.allowOperation([&](Operation *op) { 47*c66303c2SMatthias Springer if (containsSparseTensor(TypeRange(op->getResults())) || 48*c66303c2SMatthias Springer containsSparseTensor(TypeRange(op->getOperands()))) 49*c66303c2SMatthias Springer return false; 50*c66303c2SMatthias Springer if (auto funcOp = dyn_cast<func::FuncOp>(op)) { 51*c66303c2SMatthias Springer FunctionType funcType = funcOp.getFunctionType(); 52*c66303c2SMatthias Springer if (containsSparseTensor(funcType.getInputs()) || 53*c66303c2SMatthias Springer containsSparseTensor(funcType.getResults())) 54*c66303c2SMatthias Springer return false; 55*c66303c2SMatthias Springer } 56*c66303c2SMatthias Springer return true; 57*c66303c2SMatthias Springer }); 58*c66303c2SMatthias Springer 59*c66303c2SMatthias Springer if (failed(bufferization::bufferizeOp(getOperation(), options, 60*c66303c2SMatthias Springer /*copyBeforeWrite=*/false, 61*c66303c2SMatthias Springer &opFilter))) 62*c66303c2SMatthias Springer signalPassFailure(); 63*c66303c2SMatthias Springer } 64*c66303c2SMatthias Springer 65*c66303c2SMatthias Springer private: 66*c66303c2SMatthias Springer bufferization::OneShotBufferizationOptions options; 67*c66303c2SMatthias Springer }; 68*c66303c2SMatthias Springer } // namespace sparse_tensor 69*c66303c2SMatthias Springer } // namespace mlir 70*c66303c2SMatthias Springer createDenseBufferizationPass(const bufferization::OneShotBufferizationOptions & options)71*c66303c2SMatthias Springerstd::unique_ptr<Pass> mlir::createDenseBufferizationPass( 72*c66303c2SMatthias Springer const bufferization::OneShotBufferizationOptions &options) { 73*c66303c2SMatthias Springer return std::make_unique<mlir::sparse_tensor::BufferizeDenseOpsPass>(options); 74*c66303c2SMatthias Springer } 75