//===- SparseTensorPasses.cpp - Pass for autogen sparse tensor code -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8a2c9d4bbSAart Bik
9eda6f907SRiver Riddle #include "mlir/Dialect/Affine/IR/AffineOps.h"
10eda6f907SRiver Riddle #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
1157470abcSAlexander Belyaev #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
1228b6d412SAart Bik #include "mlir/Dialect/Complex/IR/Complex.h"
131f971e23SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
1423aa5a74SRiver Riddle #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
15a2c9d4bbSAart Bik #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
16a2c9d4bbSAart Bik #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
17a2c9d4bbSAart Bik #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
18a2c9d4bbSAart Bik #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
19eda6f907SRiver Riddle #include "mlir/Dialect/Tensor/IR/Tensor.h"
20a2c9d4bbSAart Bik #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21a2c9d4bbSAart Bik
22a2c9d4bbSAart Bik using namespace mlir;
2396a23911SAart Bik using namespace mlir::sparse_tensor;
24a2c9d4bbSAart Bik
25a2c9d4bbSAart Bik namespace {
26a2c9d4bbSAart Bik
//===----------------------------------------------------------------------===//
// Passes declaration.
//===----------------------------------------------------------------------===//
30a2c9d4bbSAart Bik
31a2c9d4bbSAart Bik #define GEN_PASS_CLASSES
32a2c9d4bbSAart Bik #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
33a2c9d4bbSAart Bik
//===----------------------------------------------------------------------===//
// Passes implementation.
//===----------------------------------------------------------------------===//
37a2c9d4bbSAart Bik
38a2c9d4bbSAart Bik struct SparsificationPass : public SparsificationBase<SparsificationPass> {
39a2c9d4bbSAart Bik
40a2c9d4bbSAart Bik SparsificationPass() = default;
41abb336d2SMehdi Amini SparsificationPass(const SparsificationPass &pass) = default;
SparsificationPass__anon1c79d42c0111::SparsificationPass42b85ed4e0Swren romano SparsificationPass(const SparsificationOptions &options) {
434620032eSNick Kreeger parallelization = static_cast<int32_t>(options.parallelizationStrategy);
444620032eSNick Kreeger vectorization = static_cast<int32_t>(options.vectorizationStrategy);
45b85ed4e0Swren romano vectorLength = options.vectorLength;
46b85ed4e0Swren romano enableSIMDIndex32 = options.enableSIMDIndex32;
477783a178SJavier Setoain enableVLAVectorization = options.enableVLAVectorization;
48a2c9d4bbSAart Bik }
49a2c9d4bbSAart Bik
runOnOperation__anon1c79d42c0111::SparsificationPass50a2c9d4bbSAart Bik void runOnOperation() override {
51a2c9d4bbSAart Bik auto *ctx = &getContext();
5228ebb0b6SAart Bik // Apply pre-rewriting.
5328ebb0b6SAart Bik RewritePatternSet prePatterns(ctx);
5428ebb0b6SAart Bik populateSparseTensorRewriting(prePatterns);
5528ebb0b6SAart Bik (void)applyPatternsAndFoldGreedily(getOperation(), std::move(prePatterns));
56a2c9d4bbSAart Bik // Translate strategy flags to strategy options.
574620032eSNick Kreeger SparsificationOptions options(
584620032eSNick Kreeger sparseParallelizationStrategy(parallelization),
594620032eSNick Kreeger sparseVectorizationStrategy(vectorization), vectorLength,
607783a178SJavier Setoain enableSIMDIndex32, enableVLAVectorization);
6128ebb0b6SAart Bik // Apply sparsification and vector cleanup rewriting.
6228ebb0b6SAart Bik RewritePatternSet patterns(ctx);
63a2c9d4bbSAart Bik populateSparsificationPatterns(patterns, options);
64a2c9d4bbSAart Bik vector::populateVectorToVectorCanonicalizationPatterns(patterns);
65a2c9d4bbSAart Bik (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
66a2c9d4bbSAart Bik }
67a2c9d4bbSAart Bik };
68a2c9d4bbSAart Bik
6996a23911SAart Bik class SparseTensorTypeConverter : public TypeConverter {
7096a23911SAart Bik public:
SparseTensorTypeConverter()7196a23911SAart Bik SparseTensorTypeConverter() {
7296a23911SAart Bik addConversion([](Type type) { return type; });
7396a23911SAart Bik addConversion(convertSparseTensorTypes);
7496a23911SAart Bik }
7596a23911SAart Bik // Maps each sparse tensor type to an opaque pointer.
convertSparseTensorTypes(Type type)7696a23911SAart Bik static Optional<Type> convertSparseTensorTypes(Type type) {
7796a23911SAart Bik if (getSparseTensorEncoding(type) != nullptr)
7896a23911SAart Bik return LLVM::LLVMPointerType::get(IntegerType::get(type.getContext(), 8));
7996a23911SAart Bik return llvm::None;
8096a23911SAart Bik }
8196a23911SAart Bik };
8296a23911SAart Bik
83a2c9d4bbSAart Bik struct SparseTensorConversionPass
84a2c9d4bbSAart Bik : public SparseTensorConversionBase<SparseTensorConversionPass> {
85c7e24db4Swren romano
86c7e24db4Swren romano SparseTensorConversionPass() = default;
87c7e24db4Swren romano SparseTensorConversionPass(const SparseTensorConversionPass &pass) = default;
SparseTensorConversionPass__anon1c79d42c0111::SparseTensorConversionPass88c7e24db4Swren romano SparseTensorConversionPass(const SparseTensorConversionOptions &options) {
89c7e24db4Swren romano sparseToSparse = static_cast<int32_t>(options.sparseToSparseStrategy);
90c7e24db4Swren romano }
91c7e24db4Swren romano
runOnOperation__anon1c79d42c0111::SparseTensorConversionPass92a2c9d4bbSAart Bik void runOnOperation() override {
93a2c9d4bbSAart Bik auto *ctx = &getContext();
9496a23911SAart Bik RewritePatternSet patterns(ctx);
9596a23911SAart Bik SparseTensorTypeConverter converter;
96a2c9d4bbSAart Bik ConversionTarget target(*ctx);
971b15160eSAart Bik // Everything in the sparse dialect must go!
981b15160eSAart Bik target.addIllegalDialect<SparseTensorDialect>();
99fde04aeeSAart Bik // All dynamic rules below accept new function, call, return, and various
100fde04aeeSAart Bik // tensor and bufferization operations as legal output of the rewriting
101fde04aeeSAart Bik // provided that all sparse tensor types have been fully rewritten.
10258ceae95SRiver Riddle target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
1034a3460a7SRiver Riddle return converter.isSignatureLegal(op.getFunctionType());
1044a3460a7SRiver Riddle });
10523aa5a74SRiver Riddle target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
10696a23911SAart Bik return converter.isSignatureLegal(op.getCalleeType());
10796a23911SAart Bik });
10823aa5a74SRiver Riddle target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
10923aa5a74SRiver Riddle return converter.isLegal(op.getOperandTypes());
11023aa5a74SRiver Riddle });
111236a9080SAart Bik target.addDynamicallyLegalOp<tensor::DimOp>([&](tensor::DimOp op) {
112236a9080SAart Bik return converter.isLegal(op.getOperandTypes());
113236a9080SAart Bik });
1141b15160eSAart Bik target.addDynamicallyLegalOp<tensor::CastOp>([&](tensor::CastOp op) {
115136d746eSJacques Pienaar return converter.isLegal(op.getSource().getType()) &&
116136d746eSJacques Pienaar converter.isLegal(op.getDest().getType());
1176d8e2f1eSAart Bik });
1186d8e2f1eSAart Bik target.addDynamicallyLegalOp<tensor::ExpandShapeOp>(
1196d8e2f1eSAart Bik [&](tensor::ExpandShapeOp op) {
120136d746eSJacques Pienaar return converter.isLegal(op.getSrc().getType()) &&
121136d746eSJacques Pienaar converter.isLegal(op.getResult().getType());
1226d8e2f1eSAart Bik });
1236d8e2f1eSAart Bik target.addDynamicallyLegalOp<tensor::CollapseShapeOp>(
1246d8e2f1eSAart Bik [&](tensor::CollapseShapeOp op) {
125136d746eSJacques Pienaar return converter.isLegal(op.getSrc().getType()) &&
126136d746eSJacques Pienaar converter.isLegal(op.getResult().getType());
1271b15160eSAart Bik });
128fde04aeeSAart Bik target.addDynamicallyLegalOp<bufferization::AllocTensorOp>(
129fde04aeeSAart Bik [&](bufferization::AllocTensorOp op) {
130fde04aeeSAart Bik return converter.isLegal(op.getType());
131fde04aeeSAart Bik });
132*27a431f5SMatthias Springer target.addDynamicallyLegalOp<bufferization::DeallocTensorOp>(
133*27a431f5SMatthias Springer [&](bufferization::DeallocTensorOp op) {
134*27a431f5SMatthias Springer return converter.isLegal(op.getTensor().getType());
135*27a431f5SMatthias Springer });
136236a9080SAart Bik // The following operations and dialects may be introduced by the
137236a9080SAart Bik // rewriting rules, and are therefore marked as legal.
138c66303c2SMatthias Springer target.addLegalOp<bufferization::ToMemrefOp, bufferization::ToTensorOp,
139c66303c2SMatthias Springer complex::ConstantOp, complex::NotEqualOp, linalg::FillOp,
140faa00c13SAart Bik linalg::YieldOp, tensor::ExtractOp>();
141faa00c13SAart Bik target.addLegalDialect<
142faa00c13SAart Bik arith::ArithmeticDialect, bufferization::BufferizationDialect,
143faa00c13SAart Bik LLVM::LLVMDialect, memref::MemRefDialect, scf::SCFDialect>();
144c7e24db4Swren romano // Translate strategy flags to strategy options.
145c7e24db4Swren romano SparseTensorConversionOptions options(
146c7e24db4Swren romano sparseToSparseConversionStrategy(sparseToSparse));
147236a9080SAart Bik // Populate with rules and apply rewriting rules.
14858ceae95SRiver Riddle populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
1497ceffae1SRiver Riddle converter);
15096a23911SAart Bik populateCallOpTypeConversionPattern(patterns, converter);
151c7e24db4Swren romano populateSparseTensorConversionPatterns(converter, patterns, options);
152a2c9d4bbSAart Bik if (failed(applyPartialConversion(getOperation(), target,
15396a23911SAart Bik std::move(patterns))))
154a2c9d4bbSAart Bik signalPassFailure();
155a2c9d4bbSAart Bik }
156a2c9d4bbSAart Bik };
157a2c9d4bbSAart Bik
158be0a7e9fSMehdi Amini } // namespace
159a2c9d4bbSAart Bik
1604620032eSNick Kreeger SparseParallelizationStrategy
sparseParallelizationStrategy(int32_t flag)1614620032eSNick Kreeger mlir::sparseParallelizationStrategy(int32_t flag) {
1624620032eSNick Kreeger switch (flag) {
1634620032eSNick Kreeger default:
1644620032eSNick Kreeger return SparseParallelizationStrategy::kNone;
1654620032eSNick Kreeger case 1:
1664620032eSNick Kreeger return SparseParallelizationStrategy::kDenseOuterLoop;
1674620032eSNick Kreeger case 2:
1684620032eSNick Kreeger return SparseParallelizationStrategy::kAnyStorageOuterLoop;
1694620032eSNick Kreeger case 3:
1704620032eSNick Kreeger return SparseParallelizationStrategy::kDenseAnyLoop;
1714620032eSNick Kreeger case 4:
1724620032eSNick Kreeger return SparseParallelizationStrategy::kAnyStorageAnyLoop;
1734620032eSNick Kreeger }
1744620032eSNick Kreeger }
1754620032eSNick Kreeger
sparseVectorizationStrategy(int32_t flag)1764620032eSNick Kreeger SparseVectorizationStrategy mlir::sparseVectorizationStrategy(int32_t flag) {
1774620032eSNick Kreeger switch (flag) {
1784620032eSNick Kreeger default:
1794620032eSNick Kreeger return SparseVectorizationStrategy::kNone;
1804620032eSNick Kreeger case 1:
1814620032eSNick Kreeger return SparseVectorizationStrategy::kDenseInnerLoop;
1824620032eSNick Kreeger case 2:
1834620032eSNick Kreeger return SparseVectorizationStrategy::kAnyStorageInnerLoop;
1844620032eSNick Kreeger }
1854620032eSNick Kreeger }
1864620032eSNick Kreeger
187c7e24db4Swren romano SparseToSparseConversionStrategy
sparseToSparseConversionStrategy(int32_t flag)188c7e24db4Swren romano mlir::sparseToSparseConversionStrategy(int32_t flag) {
189c7e24db4Swren romano switch (flag) {
190c7e24db4Swren romano default:
191c7e24db4Swren romano return SparseToSparseConversionStrategy::kAuto;
192c7e24db4Swren romano case 1:
193c7e24db4Swren romano return SparseToSparseConversionStrategy::kViaCOO;
194c7e24db4Swren romano case 2:
195c7e24db4Swren romano return SparseToSparseConversionStrategy::kDirect;
196c7e24db4Swren romano }
197c7e24db4Swren romano }
198c7e24db4Swren romano
createSparsificationPass()199a2c9d4bbSAart Bik std::unique_ptr<Pass> mlir::createSparsificationPass() {
200a2c9d4bbSAart Bik return std::make_unique<SparsificationPass>();
201a2c9d4bbSAart Bik }
202a2c9d4bbSAart Bik
203b85ed4e0Swren romano std::unique_ptr<Pass>
createSparsificationPass(const SparsificationOptions & options)204b85ed4e0Swren romano mlir::createSparsificationPass(const SparsificationOptions &options) {
205b85ed4e0Swren romano return std::make_unique<SparsificationPass>(options);
206b85ed4e0Swren romano }
207b85ed4e0Swren romano
createSparseTensorConversionPass()208a2c9d4bbSAart Bik std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() {
209a2c9d4bbSAart Bik return std::make_unique<SparseTensorConversionPass>();
210a2c9d4bbSAart Bik }
211c7e24db4Swren romano
createSparseTensorConversionPass(const SparseTensorConversionOptions & options)212c7e24db4Swren romano std::unique_ptr<Pass> mlir::createSparseTensorConversionPass(
213c7e24db4Swren romano const SparseTensorConversionOptions &options) {
214c7e24db4Swren romano return std::make_unique<SparseTensorConversionPass>(options);
215c7e24db4Swren romano }
216