//===- SparsificationPass.cpp - Pass for autogen sparse tensor code -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
10 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
11 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
12 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
13 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
14 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
15 
16 using namespace mlir;
17 using namespace mlir::sparse_tensor;
18 
19 namespace {
20 
21 //===----------------------------------------------------------------------===//
22 // Passes declaration.
23 //===----------------------------------------------------------------------===//
24 
25 #define GEN_PASS_CLASSES
26 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
27 
28 //===----------------------------------------------------------------------===//
29 // Passes implementation.
30 //===----------------------------------------------------------------------===//
31 
32 struct SparsificationPass : public SparsificationBase<SparsificationPass> {
33 
34   SparsificationPass() = default;
35   SparsificationPass(const SparsificationPass &pass) {}
36 
37   Option<int32_t> parallelization{
38       *this, "parallelization-strategy",
39       llvm::cl::desc("Set the parallelization strategy"), llvm::cl::init(0)};
40 
41   Option<int32_t> vectorization{
42       *this, "vectorization-strategy",
43       llvm::cl::desc("Set the vectorization strategy"), llvm::cl::init(0)};
44 
45   Option<int32_t> vectorLength{
46       *this, "vl", llvm::cl::desc("Set the vector length"), llvm::cl::init(1)};
47 
48   Option<bool> fastOutput{*this, "fast-output",
49                           llvm::cl::desc("Allows fast output buffers"),
50                           llvm::cl::init(false)};
51 
52   /// Returns parallelization strategy given on command line.
53   SparseParallelizationStrategy parallelOption() {
54     switch (parallelization) {
55     default:
56       return SparseParallelizationStrategy::kNone;
57     case 1:
58       return SparseParallelizationStrategy::kDenseOuterLoop;
59     case 2:
60       return SparseParallelizationStrategy::kAnyStorageOuterLoop;
61     case 3:
62       return SparseParallelizationStrategy::kDenseAnyLoop;
63     case 4:
64       return SparseParallelizationStrategy::kAnyStorageAnyLoop;
65     }
66   }
67 
68   /// Returns vectorization strategy given on command line.
69   SparseVectorizationStrategy vectorOption() {
70     switch (vectorization) {
71     default:
72       return SparseVectorizationStrategy::kNone;
73     case 1:
74       return SparseVectorizationStrategy::kDenseInnerLoop;
75     case 2:
76       return SparseVectorizationStrategy::kAnyStorageInnerLoop;
77     }
78   }
79 
80   void runOnOperation() override {
81     auto *ctx = &getContext();
82     RewritePatternSet patterns(ctx);
83     // Translate strategy flags to strategy options.
84     SparsificationOptions options(parallelOption(), vectorOption(),
85                                   vectorLength, fastOutput);
86     // Apply rewriting.
87     populateSparsificationPatterns(patterns, options);
88     vector::populateVectorToVectorCanonicalizationPatterns(patterns);
89     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
90   }
91 };
92 
93 class SparseTensorTypeConverter : public TypeConverter {
94 public:
95   SparseTensorTypeConverter() {
96     addConversion([](Type type) { return type; });
97     addConversion(convertSparseTensorTypes);
98   }
99   // Maps each sparse tensor type to an opaque pointer.
100   static Optional<Type> convertSparseTensorTypes(Type type) {
101     if (getSparseTensorEncoding(type) != nullptr)
102       return LLVM::LLVMPointerType::get(IntegerType::get(type.getContext(), 8));
103     return llvm::None;
104   }
105 };
106 
107 struct SparseTensorConversionPass
108     : public SparseTensorConversionBase<SparseTensorConversionPass> {
109   void runOnOperation() override {
110     auto *ctx = &getContext();
111     RewritePatternSet patterns(ctx);
112     SparseTensorTypeConverter converter;
113     ConversionTarget target(*ctx);
114     target.addIllegalOp<NewOp, ToPointersOp, ToIndicesOp, ToValuesOp>();
115     target.addDynamicallyLegalOp<FuncOp>(
116         [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
117     target.addDynamicallyLegalOp<CallOp>([&](CallOp op) {
118       return converter.isSignatureLegal(op.getCalleeType());
119     });
120     target.addDynamicallyLegalOp<ReturnOp>(
121         [&](ReturnOp op) { return converter.isLegal(op.getOperandTypes()); });
122     target.addLegalOp<ConstantOp>();
123     populateFuncOpTypeConversionPattern(patterns, converter);
124     populateCallOpTypeConversionPattern(patterns, converter);
125     populateSparseTensorConversionPatterns(converter, patterns);
126     if (failed(applyPartialConversion(getOperation(), target,
127                                       std::move(patterns))))
128       signalPassFailure();
129   }
130 };
131 
132 } // end anonymous namespace
133 
134 std::unique_ptr<Pass> mlir::createSparsificationPass() {
135   return std::make_unique<SparsificationPass>();
136 }
137 
138 std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() {
139   return std::make_unique<SparseTensorConversionPass>();
140 }
141