1 //===- SparseTensorPasses.cpp - Pass for autogen sparse tensor code -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Affine/IR/AffineOps.h"
10 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
11 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
12 #include "mlir/Dialect/Func/IR/FuncOps.h"
13 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
16 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
17 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
18 #include "mlir/Dialect/Tensor/IR/Tensor.h"
19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20 
21 using namespace mlir;
22 using namespace mlir::sparse_tensor;
23 
24 namespace {
25 
26 //===----------------------------------------------------------------------===//
27 // Passes declaration.
28 //===----------------------------------------------------------------------===//
29 
30 #define GEN_PASS_CLASSES
31 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
32 
33 //===----------------------------------------------------------------------===//
34 // Passes implementation.
35 //===----------------------------------------------------------------------===//
36 
37 struct SparsificationPass : public SparsificationBase<SparsificationPass> {
38 
39   SparsificationPass() = default;
40   SparsificationPass(const SparsificationPass &pass) = default;
41   SparsificationPass(const SparsificationOptions &options) {
42     parallelization = static_cast<int32_t>(options.parallelizationStrategy);
43     vectorization = static_cast<int32_t>(options.vectorizationStrategy);
44     vectorLength = options.vectorLength;
45     enableSIMDIndex32 = options.enableSIMDIndex32;
46     enableVLAVectorization = options.enableVLAVectorization;
47   }
48 
49   void runOnOperation() override {
50     auto *ctx = &getContext();
51     RewritePatternSet patterns(ctx);
52     // Translate strategy flags to strategy options.
53     SparsificationOptions options(
54         sparseParallelizationStrategy(parallelization),
55         sparseVectorizationStrategy(vectorization), vectorLength,
56         enableSIMDIndex32, enableVLAVectorization);
57     // Apply rewriting.
58     populateSparsificationPatterns(patterns, options);
59     vector::populateVectorToVectorCanonicalizationPatterns(patterns);
60     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
61   }
62 };
63 
64 class SparseTensorTypeConverter : public TypeConverter {
65 public:
66   SparseTensorTypeConverter() {
67     addConversion([](Type type) { return type; });
68     addConversion(convertSparseTensorTypes);
69   }
70   // Maps each sparse tensor type to an opaque pointer.
71   static Optional<Type> convertSparseTensorTypes(Type type) {
72     if (getSparseTensorEncoding(type) != nullptr)
73       return LLVM::LLVMPointerType::get(IntegerType::get(type.getContext(), 8));
74     return llvm::None;
75   }
76 };
77 
78 struct SparseTensorConversionPass
79     : public SparseTensorConversionBase<SparseTensorConversionPass> {
80 
81   SparseTensorConversionPass() = default;
82   SparseTensorConversionPass(const SparseTensorConversionPass &pass) = default;
83   SparseTensorConversionPass(const SparseTensorConversionOptions &options) {
84     sparseToSparse = static_cast<int32_t>(options.sparseToSparseStrategy);
85   }
86 
87   void runOnOperation() override {
88     auto *ctx = &getContext();
89     RewritePatternSet patterns(ctx);
90     SparseTensorTypeConverter converter;
91     ConversionTarget target(*ctx);
92     // Everything in the sparse dialect must go!
93     target.addIllegalDialect<SparseTensorDialect>();
94     // All dynamic rules below accept new function, call, return, and tensor
95     // dim and cast operations as legal output of the rewriting provided that
96     // all sparse tensor types have been fully rewritten.
97     target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
98       return converter.isSignatureLegal(op.getFunctionType());
99     });
100     target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
101       return converter.isSignatureLegal(op.getCalleeType());
102     });
103     target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
104       return converter.isLegal(op.getOperandTypes());
105     });
106     target.addDynamicallyLegalOp<tensor::DimOp>([&](tensor::DimOp op) {
107       return converter.isLegal(op.getOperandTypes());
108     });
109     target.addDynamicallyLegalOp<tensor::CastOp>([&](tensor::CastOp op) {
110       return converter.isLegal(op.getOperand().getType());
111     });
112     // The following operations and dialects may be introduced by the
113     // rewriting rules, and are therefore marked as legal.
114     target.addLegalOp<arith::CmpFOp, arith::CmpIOp, arith::ConstantOp,
115                       arith::IndexCastOp, linalg::FillOp, linalg::YieldOp,
116                       tensor::ExtractOp>();
117     target
118         .addLegalDialect<bufferization::BufferizationDialect, LLVM::LLVMDialect,
119                          memref::MemRefDialect, scf::SCFDialect>();
120     // Translate strategy flags to strategy options.
121     SparseTensorConversionOptions options(
122         sparseToSparseConversionStrategy(sparseToSparse));
123     // Populate with rules and apply rewriting rules.
124     populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
125                                                                    converter);
126     populateCallOpTypeConversionPattern(patterns, converter);
127     populateSparseTensorConversionPatterns(converter, patterns, options);
128     if (failed(applyPartialConversion(getOperation(), target,
129                                       std::move(patterns))))
130       signalPassFailure();
131   }
132 };
133 
134 } // namespace
135 
136 SparseParallelizationStrategy
137 mlir::sparseParallelizationStrategy(int32_t flag) {
138   switch (flag) {
139   default:
140     return SparseParallelizationStrategy::kNone;
141   case 1:
142     return SparseParallelizationStrategy::kDenseOuterLoop;
143   case 2:
144     return SparseParallelizationStrategy::kAnyStorageOuterLoop;
145   case 3:
146     return SparseParallelizationStrategy::kDenseAnyLoop;
147   case 4:
148     return SparseParallelizationStrategy::kAnyStorageAnyLoop;
149   }
150 }
151 
152 SparseVectorizationStrategy mlir::sparseVectorizationStrategy(int32_t flag) {
153   switch (flag) {
154   default:
155     return SparseVectorizationStrategy::kNone;
156   case 1:
157     return SparseVectorizationStrategy::kDenseInnerLoop;
158   case 2:
159     return SparseVectorizationStrategy::kAnyStorageInnerLoop;
160   }
161 }
162 
163 SparseToSparseConversionStrategy
164 mlir::sparseToSparseConversionStrategy(int32_t flag) {
165   switch (flag) {
166   default:
167     return SparseToSparseConversionStrategy::kAuto;
168   case 1:
169     return SparseToSparseConversionStrategy::kViaCOO;
170   case 2:
171     return SparseToSparseConversionStrategy::kDirect;
172   }
173 }
174 
/// Creates a sparsification pass with all strategy options at their defaults.
std::unique_ptr<Pass> mlir::createSparsificationPass() {
  return std::make_unique<SparsificationPass>();
}
178 
/// Creates a sparsification pass configured with the given strategy options.
std::unique_ptr<Pass>
mlir::createSparsificationPass(const SparsificationOptions &options) {
  return std::make_unique<SparsificationPass>(options);
}
183 
/// Creates a sparse tensor conversion pass with default (auto) strategy.
std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() {
  return std::make_unique<SparseTensorConversionPass>();
}
187 
/// Creates a sparse tensor conversion pass configured with the given
/// conversion strategy options.
std::unique_ptr<Pass> mlir::createSparseTensorConversionPass(
    const SparseTensorConversionOptions &options) {
  return std::make_unique<SparseTensorConversionPass>(options);
}
192