1 //===- SparseTensorPasses.cpp - Pass for autogen sparse tensor code -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/Affine/IR/AffineOps.h"
10 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
11 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
12 #include "mlir/Dialect/Complex/IR/Complex.h"
13 #include "mlir/Dialect/Func/IR/FuncOps.h"
14 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
15 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
16 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
17 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
18 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
19 #include "mlir/Dialect/Tensor/IR/Tensor.h"
20 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21
22 using namespace mlir;
23 using namespace mlir::sparse_tensor;
24
25 namespace {
26
27 //===----------------------------------------------------------------------===//
28 // Passes declaration.
29 //===----------------------------------------------------------------------===//
30
31 #define GEN_PASS_CLASSES
32 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
33
34 //===----------------------------------------------------------------------===//
35 // Passes implementation.
36 //===----------------------------------------------------------------------===//
37
38 struct SparsificationPass : public SparsificationBase<SparsificationPass> {
39
40 SparsificationPass() = default;
41 SparsificationPass(const SparsificationPass &pass) = default;
SparsificationPass__anon1c79d42c0111::SparsificationPass42 SparsificationPass(const SparsificationOptions &options) {
43 parallelization = static_cast<int32_t>(options.parallelizationStrategy);
44 vectorization = static_cast<int32_t>(options.vectorizationStrategy);
45 vectorLength = options.vectorLength;
46 enableSIMDIndex32 = options.enableSIMDIndex32;
47 enableVLAVectorization = options.enableVLAVectorization;
48 }
49
runOnOperation__anon1c79d42c0111::SparsificationPass50 void runOnOperation() override {
51 auto *ctx = &getContext();
52 // Apply pre-rewriting.
53 RewritePatternSet prePatterns(ctx);
54 populateSparseTensorRewriting(prePatterns);
55 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(prePatterns));
56 // Translate strategy flags to strategy options.
57 SparsificationOptions options(
58 sparseParallelizationStrategy(parallelization),
59 sparseVectorizationStrategy(vectorization), vectorLength,
60 enableSIMDIndex32, enableVLAVectorization);
61 // Apply sparsification and vector cleanup rewriting.
62 RewritePatternSet patterns(ctx);
63 populateSparsificationPatterns(patterns, options);
64 vector::populateVectorToVectorCanonicalizationPatterns(patterns);
65 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
66 }
67 };
68
69 class SparseTensorTypeConverter : public TypeConverter {
70 public:
SparseTensorTypeConverter()71 SparseTensorTypeConverter() {
72 addConversion([](Type type) { return type; });
73 addConversion(convertSparseTensorTypes);
74 }
75 // Maps each sparse tensor type to an opaque pointer.
convertSparseTensorTypes(Type type)76 static Optional<Type> convertSparseTensorTypes(Type type) {
77 if (getSparseTensorEncoding(type) != nullptr)
78 return LLVM::LLVMPointerType::get(IntegerType::get(type.getContext(), 8));
79 return llvm::None;
80 }
81 };
82
83 struct SparseTensorConversionPass
84 : public SparseTensorConversionBase<SparseTensorConversionPass> {
85
86 SparseTensorConversionPass() = default;
87 SparseTensorConversionPass(const SparseTensorConversionPass &pass) = default;
SparseTensorConversionPass__anon1c79d42c0111::SparseTensorConversionPass88 SparseTensorConversionPass(const SparseTensorConversionOptions &options) {
89 sparseToSparse = static_cast<int32_t>(options.sparseToSparseStrategy);
90 }
91
runOnOperation__anon1c79d42c0111::SparseTensorConversionPass92 void runOnOperation() override {
93 auto *ctx = &getContext();
94 RewritePatternSet patterns(ctx);
95 SparseTensorTypeConverter converter;
96 ConversionTarget target(*ctx);
97 // Everything in the sparse dialect must go!
98 target.addIllegalDialect<SparseTensorDialect>();
99 // All dynamic rules below accept new function, call, return, and various
100 // tensor and bufferization operations as legal output of the rewriting
101 // provided that all sparse tensor types have been fully rewritten.
102 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
103 return converter.isSignatureLegal(op.getFunctionType());
104 });
105 target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
106 return converter.isSignatureLegal(op.getCalleeType());
107 });
108 target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
109 return converter.isLegal(op.getOperandTypes());
110 });
111 target.addDynamicallyLegalOp<tensor::DimOp>([&](tensor::DimOp op) {
112 return converter.isLegal(op.getOperandTypes());
113 });
114 target.addDynamicallyLegalOp<tensor::CastOp>([&](tensor::CastOp op) {
115 return converter.isLegal(op.getSource().getType()) &&
116 converter.isLegal(op.getDest().getType());
117 });
118 target.addDynamicallyLegalOp<tensor::ExpandShapeOp>(
119 [&](tensor::ExpandShapeOp op) {
120 return converter.isLegal(op.getSrc().getType()) &&
121 converter.isLegal(op.getResult().getType());
122 });
123 target.addDynamicallyLegalOp<tensor::CollapseShapeOp>(
124 [&](tensor::CollapseShapeOp op) {
125 return converter.isLegal(op.getSrc().getType()) &&
126 converter.isLegal(op.getResult().getType());
127 });
128 target.addDynamicallyLegalOp<bufferization::AllocTensorOp>(
129 [&](bufferization::AllocTensorOp op) {
130 return converter.isLegal(op.getType());
131 });
132 target.addDynamicallyLegalOp<bufferization::DeallocTensorOp>(
133 [&](bufferization::DeallocTensorOp op) {
134 return converter.isLegal(op.getTensor().getType());
135 });
136 // The following operations and dialects may be introduced by the
137 // rewriting rules, and are therefore marked as legal.
138 target.addLegalOp<bufferization::ToMemrefOp, bufferization::ToTensorOp,
139 complex::ConstantOp, complex::NotEqualOp, linalg::FillOp,
140 linalg::YieldOp, tensor::ExtractOp>();
141 target.addLegalDialect<
142 arith::ArithmeticDialect, bufferization::BufferizationDialect,
143 LLVM::LLVMDialect, memref::MemRefDialect, scf::SCFDialect>();
144 // Translate strategy flags to strategy options.
145 SparseTensorConversionOptions options(
146 sparseToSparseConversionStrategy(sparseToSparse));
147 // Populate with rules and apply rewriting rules.
148 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
149 converter);
150 populateCallOpTypeConversionPattern(patterns, converter);
151 populateSparseTensorConversionPatterns(converter, patterns, options);
152 if (failed(applyPartialConversion(getOperation(), target,
153 std::move(patterns))))
154 signalPassFailure();
155 }
156 };
157
158 } // namespace
159
160 SparseParallelizationStrategy
sparseParallelizationStrategy(int32_t flag)161 mlir::sparseParallelizationStrategy(int32_t flag) {
162 switch (flag) {
163 default:
164 return SparseParallelizationStrategy::kNone;
165 case 1:
166 return SparseParallelizationStrategy::kDenseOuterLoop;
167 case 2:
168 return SparseParallelizationStrategy::kAnyStorageOuterLoop;
169 case 3:
170 return SparseParallelizationStrategy::kDenseAnyLoop;
171 case 4:
172 return SparseParallelizationStrategy::kAnyStorageAnyLoop;
173 }
174 }
175
sparseVectorizationStrategy(int32_t flag)176 SparseVectorizationStrategy mlir::sparseVectorizationStrategy(int32_t flag) {
177 switch (flag) {
178 default:
179 return SparseVectorizationStrategy::kNone;
180 case 1:
181 return SparseVectorizationStrategy::kDenseInnerLoop;
182 case 2:
183 return SparseVectorizationStrategy::kAnyStorageInnerLoop;
184 }
185 }
186
187 SparseToSparseConversionStrategy
sparseToSparseConversionStrategy(int32_t flag)188 mlir::sparseToSparseConversionStrategy(int32_t flag) {
189 switch (flag) {
190 default:
191 return SparseToSparseConversionStrategy::kAuto;
192 case 1:
193 return SparseToSparseConversionStrategy::kViaCOO;
194 case 2:
195 return SparseToSparseConversionStrategy::kDirect;
196 }
197 }
198
createSparsificationPass()199 std::unique_ptr<Pass> mlir::createSparsificationPass() {
200 return std::make_unique<SparsificationPass>();
201 }
202
203 std::unique_ptr<Pass>
createSparsificationPass(const SparsificationOptions & options)204 mlir::createSparsificationPass(const SparsificationOptions &options) {
205 return std::make_unique<SparsificationPass>(options);
206 }
207
createSparseTensorConversionPass()208 std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() {
209 return std::make_unique<SparseTensorConversionPass>();
210 }
211
createSparseTensorConversionPass(const SparseTensorConversionOptions & options)212 std::unique_ptr<Pass> mlir::createSparseTensorConversionPass(
213 const SparseTensorConversionOptions &options) {
214 return std::make_unique<SparseTensorConversionPass>(options);
215 }
216