1 //===- DenseBufferizationPass.cpp - Dense bufferization pass --------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" 10 11 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 12 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" 13 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" 14 #include "mlir/Dialect/Func/IR/FuncOps.h" 15 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 16 17 using namespace mlir; 18 using namespace mlir::func; 19 20 namespace mlir { 21 namespace sparse_tensor { 22 23 /// Return `true` if one of the given types is a sparse tensor type. 24 static bool containsSparseTensor(TypeRange types) { 25 for (Type t : types) 26 if (getSparseTensorEncoding(t)) 27 return true; 28 return false; 29 } 30 31 /// A pass that bufferizes only dense tensor ops and ignores all sparse tensor 32 /// ops. No buffer copies are inserted. All tensor OpOperands must be 33 /// inplacable. 34 class BufferizeDenseOpsPass 35 : public PassWrapper<BufferizeDenseOpsPass, OperationPass<ModuleOp>> { 36 public: 37 BufferizeDenseOpsPass( 38 const bufferization::OneShotBufferizationOptions &options) 39 : PassWrapper<BufferizeDenseOpsPass, OperationPass<ModuleOp>>(), 40 options(options) {} 41 42 void runOnOperation() override { 43 // Disallow all sparse tensor ops, so that only dense tensor ops are 44 // bufferized. 45 bufferization::OpFilter opFilter; 46 opFilter.allowOperation([&](Operation *op) { 47 if (containsSparseTensor(TypeRange(op->getResults())) || 48 containsSparseTensor(TypeRange(op->getOperands()))) 49 return false; 50 if (auto funcOp = dyn_cast<func::FuncOp>(op)) { 51 FunctionType funcType = funcOp.getFunctionType(); 52 if (containsSparseTensor(funcType.getInputs()) || 53 containsSparseTensor(funcType.getResults())) 54 return false; 55 } 56 return true; 57 }); 58 59 if (failed(bufferization::bufferizeOp(getOperation(), options, 60 /*copyBeforeWrite=*/false, 61 &opFilter))) 62 signalPassFailure(); 63 } 64 65 private: 66 bufferization::OneShotBufferizationOptions options; 67 }; 68 } // namespace sparse_tensor 69 } // namespace mlir 70 71 std::unique_ptr<Pass> mlir::createDenseBufferizationPass( 72 const bufferization::OneShotBufferizationOptions &options) { 73 return std::make_unique<mlir::sparse_tensor::BufferizeDenseOpsPass>(options); 74 } 75