//===- SparsificationPass.cpp - Pass for autogen sparse tensor code -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
10 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
11 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
12 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
13 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
14 
15 using namespace mlir;
16 
17 namespace {
18 
19 //===----------------------------------------------------------------------===//
20 // Passes declaration.
21 //===----------------------------------------------------------------------===//
22 
23 #define GEN_PASS_CLASSES
24 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
25 
26 //===----------------------------------------------------------------------===//
27 // Passes implementation.
28 //===----------------------------------------------------------------------===//
29 
30 struct SparsificationPass : public SparsificationBase<SparsificationPass> {
31 
32   SparsificationPass() = default;
33   SparsificationPass(const SparsificationPass &pass) {}
34 
35   Option<int32_t> parallelization{
36       *this, "parallelization-strategy",
37       llvm::cl::desc("Set the parallelization strategy"), llvm::cl::init(0)};
38 
39   Option<int32_t> vectorization{
40       *this, "vectorization-strategy",
41       llvm::cl::desc("Set the vectorization strategy"), llvm::cl::init(0)};
42 
43   Option<int32_t> vectorLength{
44       *this, "vl", llvm::cl::desc("Set the vector length"), llvm::cl::init(1)};
45 
46   Option<int32_t> ptrType{*this, "ptr-type",
47                           llvm::cl::desc("Set the pointer type"),
48                           llvm::cl::init(0)};
49 
50   Option<int32_t> indType{*this, "ind-type",
51                           llvm::cl::desc("Set the index type"),
52                           llvm::cl::init(0)};
53 
54   Option<bool> fastOutput{*this, "fast-output",
55                           llvm::cl::desc("Allows fast output buffers"),
56                           llvm::cl::init(false)};
57 
58   /// Returns parallelization strategy given on command line.
59   SparseParallelizationStrategy parallelOption() {
60     switch (parallelization) {
61     default:
62       return SparseParallelizationStrategy::kNone;
63     case 1:
64       return SparseParallelizationStrategy::kDenseOuterLoop;
65     case 2:
66       return SparseParallelizationStrategy::kAnyStorageOuterLoop;
67     case 3:
68       return SparseParallelizationStrategy::kDenseAnyLoop;
69     case 4:
70       return SparseParallelizationStrategy::kAnyStorageAnyLoop;
71     }
72   }
73 
74   /// Returns vectorization strategy given on command line.
75   SparseVectorizationStrategy vectorOption() {
76     switch (vectorization) {
77     default:
78       return SparseVectorizationStrategy::kNone;
79     case 1:
80       return SparseVectorizationStrategy::kDenseInnerLoop;
81     case 2:
82       return SparseVectorizationStrategy::kAnyStorageInnerLoop;
83     }
84   }
85 
86   /// Returns the requested integer type.
87   SparseIntType typeOption(int32_t option) {
88     switch (option) {
89     default:
90       return SparseIntType::kNative;
91     case 1:
92       return SparseIntType::kI64;
93     case 2:
94       return SparseIntType::kI32;
95     case 3:
96       return SparseIntType::kI16;
97     case 4:
98       return SparseIntType::kI8;
99     }
100   }
101 
102   void runOnOperation() override {
103     auto *ctx = &getContext();
104     RewritePatternSet patterns(ctx);
105     // Translate strategy flags to strategy options.
106     SparsificationOptions options(parallelOption(), vectorOption(),
107                                   vectorLength, typeOption(ptrType),
108                                   typeOption(indType), fastOutput);
109     // Apply rewriting.
110     populateSparsificationPatterns(patterns, options);
111     vector::populateVectorToVectorCanonicalizationPatterns(patterns);
112     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
113   }
114 };
115 
116 struct SparseTensorConversionPass
117     : public SparseTensorConversionBase<SparseTensorConversionPass> {
118   void runOnOperation() override {
119     auto *ctx = &getContext();
120     RewritePatternSet conversionPatterns(ctx);
121     ConversionTarget target(*ctx);
122     target
123         .addIllegalOp<sparse_tensor::FromPointerOp, sparse_tensor::ToPointersOp,
124                       sparse_tensor::ToIndicesOp, sparse_tensor::ToValuesOp>();
125     target.addLegalOp<CallOp>();
126     populateSparseTensorConversionPatterns(conversionPatterns);
127     if (failed(applyPartialConversion(getOperation(), target,
128                                       std::move(conversionPatterns))))
129       signalPassFailure();
130   }
131 };
132 
133 } // end anonymous namespace
134 
135 std::unique_ptr<Pass> mlir::createSparsificationPass() {
136   return std::make_unique<SparsificationPass>();
137 }
138 
139 std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() {
140   return std::make_unique<SparseTensorConversionPass>();
141 }
142