1 //===- SparseTensorPasses.cpp - Pass for autogen sparse tensor code -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Affine/IR/AffineOps.h"
10 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
11 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
12 #include "mlir/Dialect/Complex/IR/Complex.h"
13 #include "mlir/Dialect/Func/IR/FuncOps.h"
14 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
15 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
16 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
17 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
18 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
19 #include "mlir/Dialect/Tensor/IR/Tensor.h"
20 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21 
22 using namespace mlir;
23 using namespace mlir::sparse_tensor;
24 
25 namespace {
26 
27 //===----------------------------------------------------------------------===//
28 // Passes declaration.
29 //===----------------------------------------------------------------------===//
30 
31 #define GEN_PASS_CLASSES
32 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
33 
34 //===----------------------------------------------------------------------===//
35 // Passes implementation.
36 //===----------------------------------------------------------------------===//
37 
38 struct SparsificationPass : public SparsificationBase<SparsificationPass> {
39 
40   SparsificationPass() = default;
41   SparsificationPass(const SparsificationPass &pass) = default;
SparsificationPass__anon1c79d42c0111::SparsificationPass42   SparsificationPass(const SparsificationOptions &options) {
43     parallelization = static_cast<int32_t>(options.parallelizationStrategy);
44     vectorization = static_cast<int32_t>(options.vectorizationStrategy);
45     vectorLength = options.vectorLength;
46     enableSIMDIndex32 = options.enableSIMDIndex32;
47     enableVLAVectorization = options.enableVLAVectorization;
48   }
49 
runOnOperation__anon1c79d42c0111::SparsificationPass50   void runOnOperation() override {
51     auto *ctx = &getContext();
52     // Apply pre-rewriting.
53     RewritePatternSet prePatterns(ctx);
54     populateSparseTensorRewriting(prePatterns);
55     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(prePatterns));
56     // Translate strategy flags to strategy options.
57     SparsificationOptions options(
58         sparseParallelizationStrategy(parallelization),
59         sparseVectorizationStrategy(vectorization), vectorLength,
60         enableSIMDIndex32, enableVLAVectorization);
61     // Apply sparsification and vector cleanup rewriting.
62     RewritePatternSet patterns(ctx);
63     populateSparsificationPatterns(patterns, options);
64     vector::populateVectorToVectorCanonicalizationPatterns(patterns);
65     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
66   }
67 };
68 
69 class SparseTensorTypeConverter : public TypeConverter {
70 public:
SparseTensorTypeConverter()71   SparseTensorTypeConverter() {
72     addConversion([](Type type) { return type; });
73     addConversion(convertSparseTensorTypes);
74   }
75   // Maps each sparse tensor type to an opaque pointer.
convertSparseTensorTypes(Type type)76   static Optional<Type> convertSparseTensorTypes(Type type) {
77     if (getSparseTensorEncoding(type) != nullptr)
78       return LLVM::LLVMPointerType::get(IntegerType::get(type.getContext(), 8));
79     return llvm::None;
80   }
81 };
82 
83 struct SparseTensorConversionPass
84     : public SparseTensorConversionBase<SparseTensorConversionPass> {
85 
86   SparseTensorConversionPass() = default;
87   SparseTensorConversionPass(const SparseTensorConversionPass &pass) = default;
SparseTensorConversionPass__anon1c79d42c0111::SparseTensorConversionPass88   SparseTensorConversionPass(const SparseTensorConversionOptions &options) {
89     sparseToSparse = static_cast<int32_t>(options.sparseToSparseStrategy);
90   }
91 
runOnOperation__anon1c79d42c0111::SparseTensorConversionPass92   void runOnOperation() override {
93     auto *ctx = &getContext();
94     RewritePatternSet patterns(ctx);
95     SparseTensorTypeConverter converter;
96     ConversionTarget target(*ctx);
97     // Everything in the sparse dialect must go!
98     target.addIllegalDialect<SparseTensorDialect>();
99     // All dynamic rules below accept new function, call, return, and various
100     // tensor and bufferization operations as legal output of the rewriting
101     // provided that all sparse tensor types have been fully rewritten.
102     target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
103       return converter.isSignatureLegal(op.getFunctionType());
104     });
105     target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
106       return converter.isSignatureLegal(op.getCalleeType());
107     });
108     target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
109       return converter.isLegal(op.getOperandTypes());
110     });
111     target.addDynamicallyLegalOp<tensor::DimOp>([&](tensor::DimOp op) {
112       return converter.isLegal(op.getOperandTypes());
113     });
114     target.addDynamicallyLegalOp<tensor::CastOp>([&](tensor::CastOp op) {
115       return converter.isLegal(op.getSource().getType()) &&
116              converter.isLegal(op.getDest().getType());
117     });
118     target.addDynamicallyLegalOp<tensor::ExpandShapeOp>(
119         [&](tensor::ExpandShapeOp op) {
120           return converter.isLegal(op.getSrc().getType()) &&
121                  converter.isLegal(op.getResult().getType());
122         });
123     target.addDynamicallyLegalOp<tensor::CollapseShapeOp>(
124         [&](tensor::CollapseShapeOp op) {
125           return converter.isLegal(op.getSrc().getType()) &&
126                  converter.isLegal(op.getResult().getType());
127         });
128     target.addDynamicallyLegalOp<bufferization::AllocTensorOp>(
129         [&](bufferization::AllocTensorOp op) {
130           return converter.isLegal(op.getType());
131         });
132     target.addDynamicallyLegalOp<bufferization::DeallocTensorOp>(
133         [&](bufferization::DeallocTensorOp op) {
134           return converter.isLegal(op.getTensor().getType());
135         });
136     // The following operations and dialects may be introduced by the
137     // rewriting rules, and are therefore marked as legal.
138     target.addLegalOp<bufferization::ToMemrefOp, bufferization::ToTensorOp,
139                       complex::ConstantOp, complex::NotEqualOp, linalg::FillOp,
140                       linalg::YieldOp, tensor::ExtractOp>();
141     target.addLegalDialect<
142         arith::ArithmeticDialect, bufferization::BufferizationDialect,
143         LLVM::LLVMDialect, memref::MemRefDialect, scf::SCFDialect>();
144     // Translate strategy flags to strategy options.
145     SparseTensorConversionOptions options(
146         sparseToSparseConversionStrategy(sparseToSparse));
147     // Populate with rules and apply rewriting rules.
148     populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
149                                                                    converter);
150     populateCallOpTypeConversionPattern(patterns, converter);
151     populateSparseTensorConversionPatterns(converter, patterns, options);
152     if (failed(applyPartialConversion(getOperation(), target,
153                                       std::move(patterns))))
154       signalPassFailure();
155   }
156 };
157 
158 } // namespace
159 
160 SparseParallelizationStrategy
sparseParallelizationStrategy(int32_t flag)161 mlir::sparseParallelizationStrategy(int32_t flag) {
162   switch (flag) {
163   default:
164     return SparseParallelizationStrategy::kNone;
165   case 1:
166     return SparseParallelizationStrategy::kDenseOuterLoop;
167   case 2:
168     return SparseParallelizationStrategy::kAnyStorageOuterLoop;
169   case 3:
170     return SparseParallelizationStrategy::kDenseAnyLoop;
171   case 4:
172     return SparseParallelizationStrategy::kAnyStorageAnyLoop;
173   }
174 }
175 
sparseVectorizationStrategy(int32_t flag)176 SparseVectorizationStrategy mlir::sparseVectorizationStrategy(int32_t flag) {
177   switch (flag) {
178   default:
179     return SparseVectorizationStrategy::kNone;
180   case 1:
181     return SparseVectorizationStrategy::kDenseInnerLoop;
182   case 2:
183     return SparseVectorizationStrategy::kAnyStorageInnerLoop;
184   }
185 }
186 
187 SparseToSparseConversionStrategy
sparseToSparseConversionStrategy(int32_t flag)188 mlir::sparseToSparseConversionStrategy(int32_t flag) {
189   switch (flag) {
190   default:
191     return SparseToSparseConversionStrategy::kAuto;
192   case 1:
193     return SparseToSparseConversionStrategy::kViaCOO;
194   case 2:
195     return SparseToSparseConversionStrategy::kDirect;
196   }
197 }
198 
createSparsificationPass()199 std::unique_ptr<Pass> mlir::createSparsificationPass() {
200   return std::make_unique<SparsificationPass>();
201 }
202 
203 std::unique_ptr<Pass>
createSparsificationPass(const SparsificationOptions & options)204 mlir::createSparsificationPass(const SparsificationOptions &options) {
205   return std::make_unique<SparsificationPass>(options);
206 }
207 
createSparseTensorConversionPass()208 std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() {
209   return std::make_unique<SparseTensorConversionPass>();
210 }
211 
createSparseTensorConversionPass(const SparseTensorConversionOptions & options)212 std::unique_ptr<Pass> mlir::createSparseTensorConversionPass(
213     const SparseTensorConversionOptions &options) {
214   return std::make_unique<SparseTensorConversionPass>(options);
215 }
216