1 //===- SparseTensorPasses.cpp - Pass for autogen sparse tensor code -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
10 #include "mlir/Dialect/Func/IR/FuncOps.h"
11 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
12 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
13 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
14 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
15 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
16 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
17 
18 using namespace mlir;
19 using namespace mlir::sparse_tensor;
20 
21 namespace {
22 
23 //===----------------------------------------------------------------------===//
24 // Passes declaration.
25 //===----------------------------------------------------------------------===//
26 
27 #define GEN_PASS_CLASSES
28 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
29 
30 //===----------------------------------------------------------------------===//
31 // Passes implementation.
32 //===----------------------------------------------------------------------===//
33 
34 struct SparsificationPass : public SparsificationBase<SparsificationPass> {
35 
36   SparsificationPass() = default;
37   SparsificationPass(const SparsificationPass &pass) = default;
38   SparsificationPass(const SparsificationOptions &options) {
39     parallelization = static_cast<int32_t>(options.parallelizationStrategy);
40     vectorization = static_cast<int32_t>(options.vectorizationStrategy);
41     vectorLength = options.vectorLength;
42     enableSIMDIndex32 = options.enableSIMDIndex32;
43     enableVLAVectorization = options.enableVLAVectorization;
44   }
45 
46   void runOnOperation() override {
47     auto *ctx = &getContext();
48     RewritePatternSet patterns(ctx);
49     // Translate strategy flags to strategy options.
50     SparsificationOptions options(
51         sparseParallelizationStrategy(parallelization),
52         sparseVectorizationStrategy(vectorization), vectorLength,
53         enableSIMDIndex32, enableVLAVectorization);
54     // Apply rewriting.
55     populateSparsificationPatterns(patterns, options);
56     vector::populateVectorToVectorCanonicalizationPatterns(patterns);
57     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
58   }
59 };
60 
61 class SparseTensorTypeConverter : public TypeConverter {
62 public:
63   SparseTensorTypeConverter() {
64     addConversion([](Type type) { return type; });
65     addConversion(convertSparseTensorTypes);
66   }
67   // Maps each sparse tensor type to an opaque pointer.
68   static Optional<Type> convertSparseTensorTypes(Type type) {
69     if (getSparseTensorEncoding(type) != nullptr)
70       return LLVM::LLVMPointerType::get(IntegerType::get(type.getContext(), 8));
71     return llvm::None;
72   }
73 };
74 
75 struct SparseTensorConversionPass
76     : public SparseTensorConversionBase<SparseTensorConversionPass> {
77 
78   SparseTensorConversionPass() = default;
79   SparseTensorConversionPass(const SparseTensorConversionPass &pass) = default;
80   SparseTensorConversionPass(const SparseTensorConversionOptions &options) {
81     sparseToSparse = static_cast<int32_t>(options.sparseToSparseStrategy);
82   }
83 
84   void runOnOperation() override {
85     auto *ctx = &getContext();
86     RewritePatternSet patterns(ctx);
87     SparseTensorTypeConverter converter;
88     ConversionTarget target(*ctx);
89     // Everything in the sparse dialect must go!
90     target.addIllegalDialect<SparseTensorDialect>();
91     // All dynamic rules below accept new function, call, return, and tensor
92     // dim and cast operations as legal output of the rewriting provided that
93     // all sparse tensor types have been fully rewritten.
94     target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
95       return converter.isSignatureLegal(op.getFunctionType());
96     });
97     target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
98       return converter.isSignatureLegal(op.getCalleeType());
99     });
100     target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
101       return converter.isLegal(op.getOperandTypes());
102     });
103     target.addDynamicallyLegalOp<tensor::DimOp>([&](tensor::DimOp op) {
104       return converter.isLegal(op.getOperandTypes());
105     });
106     target.addDynamicallyLegalOp<tensor::CastOp>([&](tensor::CastOp op) {
107       return converter.isLegal(op.getOperand().getType());
108     });
109     // The following operations and dialects may be introduced by the
110     // rewriting rules, and are therefore marked as legal.
111     target.addLegalOp<arith::CmpFOp, arith::CmpIOp, arith::ConstantOp,
112                       arith::IndexCastOp, linalg::FillOp, linalg::YieldOp,
113                       tensor::ExtractOp>();
114     target
115         .addLegalDialect<bufferization::BufferizationDialect, LLVM::LLVMDialect,
116                          memref::MemRefDialect, scf::SCFDialect>();
117     // Translate strategy flags to strategy options.
118     SparseTensorConversionOptions options(
119         sparseToSparseConversionStrategy(sparseToSparse));
120     // Populate with rules and apply rewriting rules.
121     populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
122                                                                    converter);
123     populateCallOpTypeConversionPattern(patterns, converter);
124     populateSparseTensorConversionPatterns(converter, patterns, options);
125     if (failed(applyPartialConversion(getOperation(), target,
126                                       std::move(patterns))))
127       signalPassFailure();
128   }
129 };
130 
131 } // namespace
132 
133 SparseParallelizationStrategy
134 mlir::sparseParallelizationStrategy(int32_t flag) {
135   switch (flag) {
136   default:
137     return SparseParallelizationStrategy::kNone;
138   case 1:
139     return SparseParallelizationStrategy::kDenseOuterLoop;
140   case 2:
141     return SparseParallelizationStrategy::kAnyStorageOuterLoop;
142   case 3:
143     return SparseParallelizationStrategy::kDenseAnyLoop;
144   case 4:
145     return SparseParallelizationStrategy::kAnyStorageAnyLoop;
146   }
147 }
148 
149 SparseVectorizationStrategy mlir::sparseVectorizationStrategy(int32_t flag) {
150   switch (flag) {
151   default:
152     return SparseVectorizationStrategy::kNone;
153   case 1:
154     return SparseVectorizationStrategy::kDenseInnerLoop;
155   case 2:
156     return SparseVectorizationStrategy::kAnyStorageInnerLoop;
157   }
158 }
159 
160 SparseToSparseConversionStrategy
161 mlir::sparseToSparseConversionStrategy(int32_t flag) {
162   switch (flag) {
163   default:
164     return SparseToSparseConversionStrategy::kAuto;
165   case 1:
166     return SparseToSparseConversionStrategy::kViaCOO;
167   case 2:
168     return SparseToSparseConversionStrategy::kDirect;
169   }
170 }
171 
172 std::unique_ptr<Pass> mlir::createSparsificationPass() {
173   return std::make_unique<SparsificationPass>();
174 }
175 
176 std::unique_ptr<Pass>
177 mlir::createSparsificationPass(const SparsificationOptions &options) {
178   return std::make_unique<SparsificationPass>(options);
179 }
180 
181 std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() {
182   return std::make_unique<SparseTensorConversionPass>();
183 }
184 
185 std::unique_ptr<Pass> mlir::createSparseTensorConversionPass(
186     const SparseTensorConversionOptions &options) {
187   return std::make_unique<SparseTensorConversionPass>(options);
188 }
189