1 //===- SparseTensorPasses.cpp - Pass for autogen sparse tensor code -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
10 #include "mlir/Dialect/Func/IR/FuncOps.h"
11 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
12 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
13 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
14 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
15 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
16 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
17 
18 using namespace mlir;
19 using namespace mlir::sparse_tensor;
20 
21 namespace {
22 
23 //===----------------------------------------------------------------------===//
24 // Passes declaration.
25 //===----------------------------------------------------------------------===//
26 
27 #define GEN_PASS_CLASSES
28 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
29 
30 //===----------------------------------------------------------------------===//
31 // Passes implementation.
32 //===----------------------------------------------------------------------===//
33 
34 struct SparsificationPass : public SparsificationBase<SparsificationPass> {
35 
36   SparsificationPass() = default;
37   SparsificationPass(const SparsificationPass &pass) = default;
38   SparsificationPass(const SparsificationOptions &options) {
39     parallelization = static_cast<int32_t>(options.parallelizationStrategy);
40     vectorization = static_cast<int32_t>(options.vectorizationStrategy);
41     vectorLength = options.vectorLength;
42     enableSIMDIndex32 = options.enableSIMDIndex32;
43   }
44 
45   void runOnOperation() override {
46     auto *ctx = &getContext();
47     RewritePatternSet patterns(ctx);
48     // Translate strategy flags to strategy options.
49     SparsificationOptions options(
50         sparseParallelizationStrategy(parallelization),
51         sparseVectorizationStrategy(vectorization), vectorLength,
52         enableSIMDIndex32);
53     // Apply rewriting.
54     populateSparsificationPatterns(patterns, options);
55     vector::populateVectorToVectorCanonicalizationPatterns(patterns);
56     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
57   }
58 };
59 
60 class SparseTensorTypeConverter : public TypeConverter {
61 public:
62   SparseTensorTypeConverter() {
63     addConversion([](Type type) { return type; });
64     addConversion(convertSparseTensorTypes);
65   }
66   // Maps each sparse tensor type to an opaque pointer.
67   static Optional<Type> convertSparseTensorTypes(Type type) {
68     if (getSparseTensorEncoding(type) != nullptr)
69       return LLVM::LLVMPointerType::get(IntegerType::get(type.getContext(), 8));
70     return llvm::None;
71   }
72 };
73 
74 struct SparseTensorConversionPass
75     : public SparseTensorConversionBase<SparseTensorConversionPass> {
76 
77   SparseTensorConversionPass() = default;
78   SparseTensorConversionPass(const SparseTensorConversionPass &pass) = default;
79   SparseTensorConversionPass(const SparseTensorConversionOptions &options) {
80     sparseToSparse = static_cast<int32_t>(options.sparseToSparseStrategy);
81   }
82 
83   void runOnOperation() override {
84     auto *ctx = &getContext();
85     RewritePatternSet patterns(ctx);
86     SparseTensorTypeConverter converter;
87     ConversionTarget target(*ctx);
88     // Everything in the sparse dialect must go!
89     target.addIllegalDialect<SparseTensorDialect>();
90     // All dynamic rules below accept new function, call, return, and tensor
91     // dim and cast operations as legal output of the rewriting provided that
92     // all sparse tensor types have been fully rewritten.
93     target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
94       return converter.isSignatureLegal(op.getFunctionType());
95     });
96     target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
97       return converter.isSignatureLegal(op.getCalleeType());
98     });
99     target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
100       return converter.isLegal(op.getOperandTypes());
101     });
102     target.addDynamicallyLegalOp<tensor::DimOp>([&](tensor::DimOp op) {
103       return converter.isLegal(op.getOperandTypes());
104     });
105     target.addDynamicallyLegalOp<tensor::CastOp>([&](tensor::CastOp op) {
106       return converter.isLegal(op.getOperand().getType());
107     });
108     // The following operations and dialects may be introduced by the
109     // rewriting rules, and are therefore marked as legal.
110     target.addLegalOp<arith::CmpFOp, arith::CmpIOp, arith::ConstantOp,
111                       arith::IndexCastOp, linalg::FillOp, linalg::YieldOp,
112                       tensor::ExtractOp>();
113     target
114         .addLegalDialect<bufferization::BufferizationDialect, LLVM::LLVMDialect,
115                          memref::MemRefDialect, scf::SCFDialect>();
116     // Translate strategy flags to strategy options.
117     SparseTensorConversionOptions options(
118         sparseToSparseConversionStrategy(sparseToSparse));
119     // Populate with rules and apply rewriting rules.
120     populateFunctionOpInterfaceTypeConversionPattern<FuncOp>(patterns,
121                                                              converter);
122     populateCallOpTypeConversionPattern(patterns, converter);
123     populateSparseTensorConversionPatterns(converter, patterns, options);
124     if (failed(applyPartialConversion(getOperation(), target,
125                                       std::move(patterns))))
126       signalPassFailure();
127   }
128 };
129 
130 } // namespace
131 
132 SparseParallelizationStrategy
133 mlir::sparseParallelizationStrategy(int32_t flag) {
134   switch (flag) {
135   default:
136     return SparseParallelizationStrategy::kNone;
137   case 1:
138     return SparseParallelizationStrategy::kDenseOuterLoop;
139   case 2:
140     return SparseParallelizationStrategy::kAnyStorageOuterLoop;
141   case 3:
142     return SparseParallelizationStrategy::kDenseAnyLoop;
143   case 4:
144     return SparseParallelizationStrategy::kAnyStorageAnyLoop;
145   }
146 }
147 
148 SparseVectorizationStrategy mlir::sparseVectorizationStrategy(int32_t flag) {
149   switch (flag) {
150   default:
151     return SparseVectorizationStrategy::kNone;
152   case 1:
153     return SparseVectorizationStrategy::kDenseInnerLoop;
154   case 2:
155     return SparseVectorizationStrategy::kAnyStorageInnerLoop;
156   }
157 }
158 
159 SparseToSparseConversionStrategy
160 mlir::sparseToSparseConversionStrategy(int32_t flag) {
161   switch (flag) {
162   default:
163     return SparseToSparseConversionStrategy::kAuto;
164   case 1:
165     return SparseToSparseConversionStrategy::kViaCOO;
166   case 2:
167     return SparseToSparseConversionStrategy::kDirect;
168   }
169 }
170 
171 std::unique_ptr<Pass> mlir::createSparsificationPass() {
172   return std::make_unique<SparsificationPass>();
173 }
174 
175 std::unique_ptr<Pass>
176 mlir::createSparsificationPass(const SparsificationOptions &options) {
177   return std::make_unique<SparsificationPass>(options);
178 }
179 
180 std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() {
181   return std::make_unique<SparseTensorConversionPass>();
182 }
183 
184 std::unique_ptr<Pass> mlir::createSparseTensorConversionPass(
185     const SparseTensorConversionOptions &options) {
186   return std::make_unique<SparseTensorConversionPass>(options);
187 }
188