1 //===- SparseTensorPasses.cpp - Pass for autogen sparse tensor code -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Affine/IR/AffineOps.h"
10 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
11 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
12 #include "mlir/Dialect/Complex/IR/Complex.h"
13 #include "mlir/Dialect/Func/IR/FuncOps.h"
14 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
15 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
16 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
17 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
18 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
19 #include "mlir/Dialect/Tensor/IR/Tensor.h"
20 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21 
22 using namespace mlir;
23 using namespace mlir::sparse_tensor;
24 
25 namespace {
26 
27 //===----------------------------------------------------------------------===//
28 // Passes declaration.
29 //===----------------------------------------------------------------------===//
30 
31 #define GEN_PASS_CLASSES
32 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
33 
34 //===----------------------------------------------------------------------===//
35 // Passes implementation.
36 //===----------------------------------------------------------------------===//
37 
38 struct SparsificationPass : public SparsificationBase<SparsificationPass> {
39 
40   SparsificationPass() = default;
41   SparsificationPass(const SparsificationPass &pass) = default;
42   SparsificationPass(const SparsificationOptions &options) {
43     parallelization = static_cast<int32_t>(options.parallelizationStrategy);
44     vectorization = static_cast<int32_t>(options.vectorizationStrategy);
45     vectorLength = options.vectorLength;
46     enableSIMDIndex32 = options.enableSIMDIndex32;
47     enableVLAVectorization = options.enableVLAVectorization;
48   }
49 
50   void runOnOperation() override {
51     auto *ctx = &getContext();
52     RewritePatternSet patterns(ctx);
53     // Translate strategy flags to strategy options.
54     SparsificationOptions options(
55         sparseParallelizationStrategy(parallelization),
56         sparseVectorizationStrategy(vectorization), vectorLength,
57         enableSIMDIndex32, enableVLAVectorization);
58     // Apply rewriting.
59     populateSparsificationPatterns(patterns, options);
60     vector::populateVectorToVectorCanonicalizationPatterns(patterns);
61     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
62   }
63 };
64 
65 class SparseTensorTypeConverter : public TypeConverter {
66 public:
67   SparseTensorTypeConverter() {
68     addConversion([](Type type) { return type; });
69     addConversion(convertSparseTensorTypes);
70   }
71   // Maps each sparse tensor type to an opaque pointer.
72   static Optional<Type> convertSparseTensorTypes(Type type) {
73     if (getSparseTensorEncoding(type) != nullptr)
74       return LLVM::LLVMPointerType::get(IntegerType::get(type.getContext(), 8));
75     return llvm::None;
76   }
77 };
78 
79 struct SparseTensorConversionPass
80     : public SparseTensorConversionBase<SparseTensorConversionPass> {
81 
82   SparseTensorConversionPass() = default;
83   SparseTensorConversionPass(const SparseTensorConversionPass &pass) = default;
84   SparseTensorConversionPass(const SparseTensorConversionOptions &options) {
85     sparseToSparse = static_cast<int32_t>(options.sparseToSparseStrategy);
86   }
87 
88   void runOnOperation() override {
89     auto *ctx = &getContext();
90     RewritePatternSet patterns(ctx);
91     SparseTensorTypeConverter converter;
92     ConversionTarget target(*ctx);
93     // Everything in the sparse dialect must go!
94     target.addIllegalDialect<SparseTensorDialect>();
95     // All dynamic rules below accept new function, call, return, and various
96     // tensor and bufferization operations as legal output of the rewriting
97     // provided that all sparse tensor types have been fully rewritten.
98     target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
99       return converter.isSignatureLegal(op.getFunctionType());
100     });
101     target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
102       return converter.isSignatureLegal(op.getCalleeType());
103     });
104     target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
105       return converter.isLegal(op.getOperandTypes());
106     });
107     target.addDynamicallyLegalOp<tensor::DimOp>([&](tensor::DimOp op) {
108       return converter.isLegal(op.getOperandTypes());
109     });
110     target.addDynamicallyLegalOp<tensor::CastOp>([&](tensor::CastOp op) {
111       return converter.isLegal(op.getSource().getType()) &&
112              converter.isLegal(op.getDest().getType());
113     });
114     target.addDynamicallyLegalOp<tensor::ExpandShapeOp>(
115         [&](tensor::ExpandShapeOp op) {
116           return converter.isLegal(op.getSrc().getType()) &&
117                  converter.isLegal(op.getResult().getType());
118         });
119     target.addDynamicallyLegalOp<tensor::CollapseShapeOp>(
120         [&](tensor::CollapseShapeOp op) {
121           return converter.isLegal(op.getSrc().getType()) &&
122                  converter.isLegal(op.getResult().getType());
123         });
124     target.addDynamicallyLegalOp<bufferization::AllocTensorOp>(
125         [&](bufferization::AllocTensorOp op) {
126           return converter.isLegal(op.getType());
127         });
128     // The following operations and dialects may be introduced by the
129     // rewriting rules, and are therefore marked as legal.
130     target.addLegalOp<bufferization::ToMemrefOp, bufferization::ToTensorOp,
131                       complex::ConstantOp, complex::NotEqualOp, linalg::FillOp,
132                       linalg::YieldOp, tensor::ExtractOp>();
133     target.addLegalDialect<
134         arith::ArithmeticDialect, bufferization::BufferizationDialect,
135         LLVM::LLVMDialect, memref::MemRefDialect, scf::SCFDialect>();
136     target.addDynamicallyLegalOp<bufferization::AllocTensorOp>(
137         [&](bufferization::AllocTensorOp op) {
138           // Dense tensors are legal, sparse tensors are not.
139           return !static_cast<bool>(op.getType().getEncoding());
140         });
141     // Translate strategy flags to strategy options.
142     SparseTensorConversionOptions options(
143         sparseToSparseConversionStrategy(sparseToSparse));
144     // Populate with rules and apply rewriting rules.
145     populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
146                                                                    converter);
147     populateCallOpTypeConversionPattern(patterns, converter);
148     populateSparseTensorConversionPatterns(converter, patterns, options);
149     if (failed(applyPartialConversion(getOperation(), target,
150                                       std::move(patterns))))
151       signalPassFailure();
152   }
153 };
154 
155 } // namespace
156 
157 SparseParallelizationStrategy
158 mlir::sparseParallelizationStrategy(int32_t flag) {
159   switch (flag) {
160   default:
161     return SparseParallelizationStrategy::kNone;
162   case 1:
163     return SparseParallelizationStrategy::kDenseOuterLoop;
164   case 2:
165     return SparseParallelizationStrategy::kAnyStorageOuterLoop;
166   case 3:
167     return SparseParallelizationStrategy::kDenseAnyLoop;
168   case 4:
169     return SparseParallelizationStrategy::kAnyStorageAnyLoop;
170   }
171 }
172 
173 SparseVectorizationStrategy mlir::sparseVectorizationStrategy(int32_t flag) {
174   switch (flag) {
175   default:
176     return SparseVectorizationStrategy::kNone;
177   case 1:
178     return SparseVectorizationStrategy::kDenseInnerLoop;
179   case 2:
180     return SparseVectorizationStrategy::kAnyStorageInnerLoop;
181   }
182 }
183 
184 SparseToSparseConversionStrategy
185 mlir::sparseToSparseConversionStrategy(int32_t flag) {
186   switch (flag) {
187   default:
188     return SparseToSparseConversionStrategy::kAuto;
189   case 1:
190     return SparseToSparseConversionStrategy::kViaCOO;
191   case 2:
192     return SparseToSparseConversionStrategy::kDirect;
193   }
194 }
195 
196 std::unique_ptr<Pass> mlir::createSparsificationPass() {
197   return std::make_unique<SparsificationPass>();
198 }
199 
200 std::unique_ptr<Pass>
201 mlir::createSparsificationPass(const SparsificationOptions &options) {
202   return std::make_unique<SparsificationPass>(options);
203 }
204 
205 std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() {
206   return std::make_unique<SparseTensorConversionPass>();
207 }
208 
209 std::unique_ptr<Pass> mlir::createSparseTensorConversionPass(
210     const SparseTensorConversionOptions &options) {
211   return std::make_unique<SparseTensorConversionPass>(options);
212 }
213