1221856f5Swren romano //===- SparseTensorPasses.cpp - Pass for autogen sparse tensor code -------===//
2a2c9d4bbSAart Bik //
3a2c9d4bbSAart Bik // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4a2c9d4bbSAart Bik // See https://llvm.org/LICENSE.txt for license information.
5a2c9d4bbSAart Bik // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6a2c9d4bbSAart Bik //
7a2c9d4bbSAart Bik //===----------------------------------------------------------------------===//
8a2c9d4bbSAart Bik 
9eda6f907SRiver Riddle #include "mlir/Dialect/Affine/IR/AffineOps.h"
10eda6f907SRiver Riddle #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
1157470abcSAlexander Belyaev #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
1228b6d412SAart Bik #include "mlir/Dialect/Complex/IR/Complex.h"
131f971e23SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
1423aa5a74SRiver Riddle #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
15a2c9d4bbSAart Bik #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
16a2c9d4bbSAart Bik #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
17a2c9d4bbSAart Bik #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
18a2c9d4bbSAart Bik #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
19eda6f907SRiver Riddle #include "mlir/Dialect/Tensor/IR/Tensor.h"
20a2c9d4bbSAart Bik #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21a2c9d4bbSAart Bik 
22a2c9d4bbSAart Bik using namespace mlir;
2396a23911SAart Bik using namespace mlir::sparse_tensor;
24a2c9d4bbSAart Bik 
25a2c9d4bbSAart Bik namespace {
26a2c9d4bbSAart Bik 
27a2c9d4bbSAart Bik //===----------------------------------------------------------------------===//
28a2c9d4bbSAart Bik // Passes declaration.
29a2c9d4bbSAart Bik //===----------------------------------------------------------------------===//
30a2c9d4bbSAart Bik 
31a2c9d4bbSAart Bik #define GEN_PASS_CLASSES
32a2c9d4bbSAart Bik #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
33a2c9d4bbSAart Bik 
34a2c9d4bbSAart Bik //===----------------------------------------------------------------------===//
35a2c9d4bbSAart Bik // Passes implementation.
36a2c9d4bbSAart Bik //===----------------------------------------------------------------------===//
37a2c9d4bbSAart Bik 
38a2c9d4bbSAart Bik struct SparsificationPass : public SparsificationBase<SparsificationPass> {
39a2c9d4bbSAart Bik 
40a2c9d4bbSAart Bik   SparsificationPass() = default;
41abb336d2SMehdi Amini   SparsificationPass(const SparsificationPass &pass) = default;
SparsificationPass__anon1c79d42c0111::SparsificationPass42b85ed4e0Swren romano   SparsificationPass(const SparsificationOptions &options) {
434620032eSNick Kreeger     parallelization = static_cast<int32_t>(options.parallelizationStrategy);
444620032eSNick Kreeger     vectorization = static_cast<int32_t>(options.vectorizationStrategy);
45b85ed4e0Swren romano     vectorLength = options.vectorLength;
46b85ed4e0Swren romano     enableSIMDIndex32 = options.enableSIMDIndex32;
477783a178SJavier Setoain     enableVLAVectorization = options.enableVLAVectorization;
48a2c9d4bbSAart Bik   }
49a2c9d4bbSAart Bik 
runOnOperation__anon1c79d42c0111::SparsificationPass50a2c9d4bbSAart Bik   void runOnOperation() override {
51a2c9d4bbSAart Bik     auto *ctx = &getContext();
5228ebb0b6SAart Bik     // Apply pre-rewriting.
5328ebb0b6SAart Bik     RewritePatternSet prePatterns(ctx);
5428ebb0b6SAart Bik     populateSparseTensorRewriting(prePatterns);
5528ebb0b6SAart Bik     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(prePatterns));
56a2c9d4bbSAart Bik     // Translate strategy flags to strategy options.
574620032eSNick Kreeger     SparsificationOptions options(
584620032eSNick Kreeger         sparseParallelizationStrategy(parallelization),
594620032eSNick Kreeger         sparseVectorizationStrategy(vectorization), vectorLength,
607783a178SJavier Setoain         enableSIMDIndex32, enableVLAVectorization);
6128ebb0b6SAart Bik     // Apply sparsification and vector cleanup rewriting.
6228ebb0b6SAart Bik     RewritePatternSet patterns(ctx);
63a2c9d4bbSAart Bik     populateSparsificationPatterns(patterns, options);
64a2c9d4bbSAart Bik     vector::populateVectorToVectorCanonicalizationPatterns(patterns);
65a2c9d4bbSAart Bik     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
66a2c9d4bbSAart Bik   }
67a2c9d4bbSAart Bik };
68a2c9d4bbSAart Bik 
6996a23911SAart Bik class SparseTensorTypeConverter : public TypeConverter {
7096a23911SAart Bik public:
SparseTensorTypeConverter()7196a23911SAart Bik   SparseTensorTypeConverter() {
7296a23911SAart Bik     addConversion([](Type type) { return type; });
7396a23911SAart Bik     addConversion(convertSparseTensorTypes);
7496a23911SAart Bik   }
7596a23911SAart Bik   // Maps each sparse tensor type to an opaque pointer.
convertSparseTensorTypes(Type type)7696a23911SAart Bik   static Optional<Type> convertSparseTensorTypes(Type type) {
7796a23911SAart Bik     if (getSparseTensorEncoding(type) != nullptr)
7896a23911SAart Bik       return LLVM::LLVMPointerType::get(IntegerType::get(type.getContext(), 8));
7996a23911SAart Bik     return llvm::None;
8096a23911SAart Bik   }
8196a23911SAart Bik };
8296a23911SAart Bik 
83a2c9d4bbSAart Bik struct SparseTensorConversionPass
84a2c9d4bbSAart Bik     : public SparseTensorConversionBase<SparseTensorConversionPass> {
85c7e24db4Swren romano 
86c7e24db4Swren romano   SparseTensorConversionPass() = default;
87c7e24db4Swren romano   SparseTensorConversionPass(const SparseTensorConversionPass &pass) = default;
SparseTensorConversionPass__anon1c79d42c0111::SparseTensorConversionPass88c7e24db4Swren romano   SparseTensorConversionPass(const SparseTensorConversionOptions &options) {
89c7e24db4Swren romano     sparseToSparse = static_cast<int32_t>(options.sparseToSparseStrategy);
90c7e24db4Swren romano   }
91c7e24db4Swren romano 
runOnOperation__anon1c79d42c0111::SparseTensorConversionPass92a2c9d4bbSAart Bik   void runOnOperation() override {
93a2c9d4bbSAart Bik     auto *ctx = &getContext();
9496a23911SAart Bik     RewritePatternSet patterns(ctx);
9596a23911SAart Bik     SparseTensorTypeConverter converter;
96a2c9d4bbSAart Bik     ConversionTarget target(*ctx);
971b15160eSAart Bik     // Everything in the sparse dialect must go!
981b15160eSAart Bik     target.addIllegalDialect<SparseTensorDialect>();
99fde04aeeSAart Bik     // All dynamic rules below accept new function, call, return, and various
100fde04aeeSAart Bik     // tensor and bufferization operations as legal output of the rewriting
101fde04aeeSAart Bik     // provided that all sparse tensor types have been fully rewritten.
10258ceae95SRiver Riddle     target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
1034a3460a7SRiver Riddle       return converter.isSignatureLegal(op.getFunctionType());
1044a3460a7SRiver Riddle     });
10523aa5a74SRiver Riddle     target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
10696a23911SAart Bik       return converter.isSignatureLegal(op.getCalleeType());
10796a23911SAart Bik     });
10823aa5a74SRiver Riddle     target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
10923aa5a74SRiver Riddle       return converter.isLegal(op.getOperandTypes());
11023aa5a74SRiver Riddle     });
111236a9080SAart Bik     target.addDynamicallyLegalOp<tensor::DimOp>([&](tensor::DimOp op) {
112236a9080SAart Bik       return converter.isLegal(op.getOperandTypes());
113236a9080SAart Bik     });
1141b15160eSAart Bik     target.addDynamicallyLegalOp<tensor::CastOp>([&](tensor::CastOp op) {
115136d746eSJacques Pienaar       return converter.isLegal(op.getSource().getType()) &&
116136d746eSJacques Pienaar              converter.isLegal(op.getDest().getType());
1176d8e2f1eSAart Bik     });
1186d8e2f1eSAart Bik     target.addDynamicallyLegalOp<tensor::ExpandShapeOp>(
1196d8e2f1eSAart Bik         [&](tensor::ExpandShapeOp op) {
120136d746eSJacques Pienaar           return converter.isLegal(op.getSrc().getType()) &&
121136d746eSJacques Pienaar                  converter.isLegal(op.getResult().getType());
1226d8e2f1eSAart Bik         });
1236d8e2f1eSAart Bik     target.addDynamicallyLegalOp<tensor::CollapseShapeOp>(
1246d8e2f1eSAart Bik         [&](tensor::CollapseShapeOp op) {
125136d746eSJacques Pienaar           return converter.isLegal(op.getSrc().getType()) &&
126136d746eSJacques Pienaar                  converter.isLegal(op.getResult().getType());
1271b15160eSAart Bik         });
128fde04aeeSAart Bik     target.addDynamicallyLegalOp<bufferization::AllocTensorOp>(
129fde04aeeSAart Bik         [&](bufferization::AllocTensorOp op) {
130fde04aeeSAart Bik           return converter.isLegal(op.getType());
131fde04aeeSAart Bik         });
132*27a431f5SMatthias Springer     target.addDynamicallyLegalOp<bufferization::DeallocTensorOp>(
133*27a431f5SMatthias Springer         [&](bufferization::DeallocTensorOp op) {
134*27a431f5SMatthias Springer           return converter.isLegal(op.getTensor().getType());
135*27a431f5SMatthias Springer         });
136236a9080SAart Bik     // The following operations and dialects may be introduced by the
137236a9080SAart Bik     // rewriting rules, and are therefore marked as legal.
138c66303c2SMatthias Springer     target.addLegalOp<bufferization::ToMemrefOp, bufferization::ToTensorOp,
139c66303c2SMatthias Springer                       complex::ConstantOp, complex::NotEqualOp, linalg::FillOp,
140faa00c13SAart Bik                       linalg::YieldOp, tensor::ExtractOp>();
141faa00c13SAart Bik     target.addLegalDialect<
142faa00c13SAart Bik         arith::ArithmeticDialect, bufferization::BufferizationDialect,
143faa00c13SAart Bik         LLVM::LLVMDialect, memref::MemRefDialect, scf::SCFDialect>();
144c7e24db4Swren romano     // Translate strategy flags to strategy options.
145c7e24db4Swren romano     SparseTensorConversionOptions options(
146c7e24db4Swren romano         sparseToSparseConversionStrategy(sparseToSparse));
147236a9080SAart Bik     // Populate with rules and apply rewriting rules.
14858ceae95SRiver Riddle     populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
1497ceffae1SRiver Riddle                                                                    converter);
15096a23911SAart Bik     populateCallOpTypeConversionPattern(patterns, converter);
151c7e24db4Swren romano     populateSparseTensorConversionPatterns(converter, patterns, options);
152a2c9d4bbSAart Bik     if (failed(applyPartialConversion(getOperation(), target,
15396a23911SAart Bik                                       std::move(patterns))))
154a2c9d4bbSAart Bik       signalPassFailure();
155a2c9d4bbSAart Bik   }
156a2c9d4bbSAart Bik };
157a2c9d4bbSAart Bik 
158be0a7e9fSMehdi Amini } // namespace
159a2c9d4bbSAart Bik 
1604620032eSNick Kreeger SparseParallelizationStrategy
sparseParallelizationStrategy(int32_t flag)1614620032eSNick Kreeger mlir::sparseParallelizationStrategy(int32_t flag) {
1624620032eSNick Kreeger   switch (flag) {
1634620032eSNick Kreeger   default:
1644620032eSNick Kreeger     return SparseParallelizationStrategy::kNone;
1654620032eSNick Kreeger   case 1:
1664620032eSNick Kreeger     return SparseParallelizationStrategy::kDenseOuterLoop;
1674620032eSNick Kreeger   case 2:
1684620032eSNick Kreeger     return SparseParallelizationStrategy::kAnyStorageOuterLoop;
1694620032eSNick Kreeger   case 3:
1704620032eSNick Kreeger     return SparseParallelizationStrategy::kDenseAnyLoop;
1714620032eSNick Kreeger   case 4:
1724620032eSNick Kreeger     return SparseParallelizationStrategy::kAnyStorageAnyLoop;
1734620032eSNick Kreeger   }
1744620032eSNick Kreeger }
1754620032eSNick Kreeger 
sparseVectorizationStrategy(int32_t flag)1764620032eSNick Kreeger SparseVectorizationStrategy mlir::sparseVectorizationStrategy(int32_t flag) {
1774620032eSNick Kreeger   switch (flag) {
1784620032eSNick Kreeger   default:
1794620032eSNick Kreeger     return SparseVectorizationStrategy::kNone;
1804620032eSNick Kreeger   case 1:
1814620032eSNick Kreeger     return SparseVectorizationStrategy::kDenseInnerLoop;
1824620032eSNick Kreeger   case 2:
1834620032eSNick Kreeger     return SparseVectorizationStrategy::kAnyStorageInnerLoop;
1844620032eSNick Kreeger   }
1854620032eSNick Kreeger }
1864620032eSNick Kreeger 
187c7e24db4Swren romano SparseToSparseConversionStrategy
sparseToSparseConversionStrategy(int32_t flag)188c7e24db4Swren romano mlir::sparseToSparseConversionStrategy(int32_t flag) {
189c7e24db4Swren romano   switch (flag) {
190c7e24db4Swren romano   default:
191c7e24db4Swren romano     return SparseToSparseConversionStrategy::kAuto;
192c7e24db4Swren romano   case 1:
193c7e24db4Swren romano     return SparseToSparseConversionStrategy::kViaCOO;
194c7e24db4Swren romano   case 2:
195c7e24db4Swren romano     return SparseToSparseConversionStrategy::kDirect;
196c7e24db4Swren romano   }
197c7e24db4Swren romano }
198c7e24db4Swren romano 
createSparsificationPass()199a2c9d4bbSAart Bik std::unique_ptr<Pass> mlir::createSparsificationPass() {
200a2c9d4bbSAart Bik   return std::make_unique<SparsificationPass>();
201a2c9d4bbSAart Bik }
202a2c9d4bbSAart Bik 
203b85ed4e0Swren romano std::unique_ptr<Pass>
createSparsificationPass(const SparsificationOptions & options)204b85ed4e0Swren romano mlir::createSparsificationPass(const SparsificationOptions &options) {
205b85ed4e0Swren romano   return std::make_unique<SparsificationPass>(options);
206b85ed4e0Swren romano }
207b85ed4e0Swren romano 
createSparseTensorConversionPass()208a2c9d4bbSAart Bik std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() {
209a2c9d4bbSAart Bik   return std::make_unique<SparseTensorConversionPass>();
210a2c9d4bbSAart Bik }
211c7e24db4Swren romano 
createSparseTensorConversionPass(const SparseTensorConversionOptions & options)212c7e24db4Swren romano std::unique_ptr<Pass> mlir::createSparseTensorConversionPass(
213c7e24db4Swren romano     const SparseTensorConversionOptions &options) {
214c7e24db4Swren romano   return std::make_unique<SparseTensorConversionPass>(options);
215c7e24db4Swren romano }
216