1 //===- SparseTensorPasses.cpp - Pass for autogen sparse tensor code -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Affine/IR/AffineOps.h"
10 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
11 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
12 #include "mlir/Dialect/Complex/IR/Complex.h"
13 #include "mlir/Dialect/Func/IR/FuncOps.h"
14 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
15 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
16 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
17 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
18 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
19 #include "mlir/Dialect/Tensor/IR/Tensor.h"
20 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21 
22 using namespace mlir;
23 using namespace mlir::sparse_tensor;
24 
25 namespace {
26 
27 //===----------------------------------------------------------------------===//
28 // Passes declaration.
29 //===----------------------------------------------------------------------===//
30 
31 #define GEN_PASS_CLASSES
32 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
33 
34 //===----------------------------------------------------------------------===//
35 // Passes implementation.
36 //===----------------------------------------------------------------------===//
37 
38 struct SparsificationPass : public SparsificationBase<SparsificationPass> {
39 
40   SparsificationPass() = default;
41   SparsificationPass(const SparsificationPass &pass) = default;
42   SparsificationPass(const SparsificationOptions &options) {
43     parallelization = static_cast<int32_t>(options.parallelizationStrategy);
44     vectorization = static_cast<int32_t>(options.vectorizationStrategy);
45     vectorLength = options.vectorLength;
46     enableSIMDIndex32 = options.enableSIMDIndex32;
47     enableVLAVectorization = options.enableVLAVectorization;
48   }
49 
50   void runOnOperation() override {
51     auto *ctx = &getContext();
52     // Apply pre-rewriting.
53     RewritePatternSet prePatterns(ctx);
54     populateSparseTensorRewriting(prePatterns);
55     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(prePatterns));
56     // Translate strategy flags to strategy options.
57     SparsificationOptions options(
58         sparseParallelizationStrategy(parallelization),
59         sparseVectorizationStrategy(vectorization), vectorLength,
60         enableSIMDIndex32, enableVLAVectorization);
61     // Apply sparsification and vector cleanup rewriting.
62     RewritePatternSet patterns(ctx);
63     populateSparsificationPatterns(patterns, options);
64     vector::populateVectorToVectorCanonicalizationPatterns(patterns);
65     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
66   }
67 };
68 
69 class SparseTensorTypeConverter : public TypeConverter {
70 public:
71   SparseTensorTypeConverter() {
72     addConversion([](Type type) { return type; });
73     addConversion(convertSparseTensorTypes);
74   }
75   // Maps each sparse tensor type to an opaque pointer.
76   static Optional<Type> convertSparseTensorTypes(Type type) {
77     if (getSparseTensorEncoding(type) != nullptr)
78       return LLVM::LLVMPointerType::get(IntegerType::get(type.getContext(), 8));
79     return llvm::None;
80   }
81 };
82 
83 struct SparseTensorConversionPass
84     : public SparseTensorConversionBase<SparseTensorConversionPass> {
85 
86   SparseTensorConversionPass() = default;
87   SparseTensorConversionPass(const SparseTensorConversionPass &pass) = default;
88   SparseTensorConversionPass(const SparseTensorConversionOptions &options) {
89     sparseToSparse = static_cast<int32_t>(options.sparseToSparseStrategy);
90   }
91 
92   void runOnOperation() override {
93     auto *ctx = &getContext();
94     RewritePatternSet patterns(ctx);
95     SparseTensorTypeConverter converter;
96     ConversionTarget target(*ctx);
97     // Everything in the sparse dialect must go!
98     target.addIllegalDialect<SparseTensorDialect>();
99     // All dynamic rules below accept new function, call, return, and various
100     // tensor and bufferization operations as legal output of the rewriting
101     // provided that all sparse tensor types have been fully rewritten.
102     target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
103       return converter.isSignatureLegal(op.getFunctionType());
104     });
105     target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
106       return converter.isSignatureLegal(op.getCalleeType());
107     });
108     target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
109       return converter.isLegal(op.getOperandTypes());
110     });
111     target.addDynamicallyLegalOp<tensor::DimOp>([&](tensor::DimOp op) {
112       return converter.isLegal(op.getOperandTypes());
113     });
114     target.addDynamicallyLegalOp<tensor::CastOp>([&](tensor::CastOp op) {
115       return converter.isLegal(op.getSource().getType()) &&
116              converter.isLegal(op.getDest().getType());
117     });
118     target.addDynamicallyLegalOp<tensor::ExpandShapeOp>(
119         [&](tensor::ExpandShapeOp op) {
120           return converter.isLegal(op.getSrc().getType()) &&
121                  converter.isLegal(op.getResult().getType());
122         });
123     target.addDynamicallyLegalOp<tensor::CollapseShapeOp>(
124         [&](tensor::CollapseShapeOp op) {
125           return converter.isLegal(op.getSrc().getType()) &&
126                  converter.isLegal(op.getResult().getType());
127         });
128     target.addDynamicallyLegalOp<bufferization::AllocTensorOp>(
129         [&](bufferization::AllocTensorOp op) {
130           return converter.isLegal(op.getType());
131         });
132     // The following operations and dialects may be introduced by the
133     // rewriting rules, and are therefore marked as legal.
134     target.addLegalOp<bufferization::ToMemrefOp, bufferization::ToTensorOp,
135                       complex::ConstantOp, complex::NotEqualOp, linalg::FillOp,
136                       linalg::YieldOp, tensor::ExtractOp>();
137     target.addLegalDialect<
138         arith::ArithmeticDialect, bufferization::BufferizationDialect,
139         LLVM::LLVMDialect, memref::MemRefDialect, scf::SCFDialect>();
140     target.addDynamicallyLegalOp<bufferization::AllocTensorOp>(
141         [&](bufferization::AllocTensorOp op) {
142           // Dense tensors are legal, sparse tensors are not.
143           return !static_cast<bool>(op.getType().getEncoding());
144         });
145     // Translate strategy flags to strategy options.
146     SparseTensorConversionOptions options(
147         sparseToSparseConversionStrategy(sparseToSparse));
148     // Populate with rules and apply rewriting rules.
149     populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
150                                                                    converter);
151     populateCallOpTypeConversionPattern(patterns, converter);
152     populateSparseTensorConversionPatterns(converter, patterns, options);
153     if (failed(applyPartialConversion(getOperation(), target,
154                                       std::move(patterns))))
155       signalPassFailure();
156   }
157 };
158 
159 } // namespace
160 
161 SparseParallelizationStrategy
162 mlir::sparseParallelizationStrategy(int32_t flag) {
163   switch (flag) {
164   default:
165     return SparseParallelizationStrategy::kNone;
166   case 1:
167     return SparseParallelizationStrategy::kDenseOuterLoop;
168   case 2:
169     return SparseParallelizationStrategy::kAnyStorageOuterLoop;
170   case 3:
171     return SparseParallelizationStrategy::kDenseAnyLoop;
172   case 4:
173     return SparseParallelizationStrategy::kAnyStorageAnyLoop;
174   }
175 }
176 
177 SparseVectorizationStrategy mlir::sparseVectorizationStrategy(int32_t flag) {
178   switch (flag) {
179   default:
180     return SparseVectorizationStrategy::kNone;
181   case 1:
182     return SparseVectorizationStrategy::kDenseInnerLoop;
183   case 2:
184     return SparseVectorizationStrategy::kAnyStorageInnerLoop;
185   }
186 }
187 
188 SparseToSparseConversionStrategy
189 mlir::sparseToSparseConversionStrategy(int32_t flag) {
190   switch (flag) {
191   default:
192     return SparseToSparseConversionStrategy::kAuto;
193   case 1:
194     return SparseToSparseConversionStrategy::kViaCOO;
195   case 2:
196     return SparseToSparseConversionStrategy::kDirect;
197   }
198 }
199 
200 std::unique_ptr<Pass> mlir::createSparsificationPass() {
201   return std::make_unique<SparsificationPass>();
202 }
203 
204 std::unique_ptr<Pass>
205 mlir::createSparsificationPass(const SparsificationOptions &options) {
206   return std::make_unique<SparsificationPass>(options);
207 }
208 
209 std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() {
210   return std::make_unique<SparseTensorConversionPass>();
211 }
212 
213 std::unique_ptr<Pass> mlir::createSparseTensorConversionPass(
214     const SparseTensorConversionOptions &options) {
215   return std::make_unique<SparseTensorConversionPass>(options);
216 }
217