//===- SparsificationPass.cpp - Pass for autogen sparse tensor code -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
10 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
11 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
12 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
13 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
14 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
15 
16 using namespace mlir;
17 using namespace mlir::sparse_tensor;
18 
19 namespace {
20 
21 //===----------------------------------------------------------------------===//
22 // Passes declaration.
23 //===----------------------------------------------------------------------===//
24 
25 #define GEN_PASS_CLASSES
26 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
27 
28 //===----------------------------------------------------------------------===//
29 // Passes implementation.
30 //===----------------------------------------------------------------------===//
31 
32 struct SparsificationPass : public SparsificationBase<SparsificationPass> {
33 
34   SparsificationPass() = default;
35   SparsificationPass(const SparsificationPass &pass) {}
36 
37   Option<int32_t> parallelization{
38       *this, "parallelization-strategy",
39       llvm::cl::desc("Set the parallelization strategy"), llvm::cl::init(0)};
40 
41   Option<int32_t> vectorization{
42       *this, "vectorization-strategy",
43       llvm::cl::desc("Set the vectorization strategy"), llvm::cl::init(0)};
44 
45   Option<int32_t> vectorLength{
46       *this, "vl", llvm::cl::desc("Set the vector length"), llvm::cl::init(1)};
47 
48   /// Returns parallelization strategy given on command line.
49   SparseParallelizationStrategy parallelOption() {
50     switch (parallelization) {
51     default:
52       return SparseParallelizationStrategy::kNone;
53     case 1:
54       return SparseParallelizationStrategy::kDenseOuterLoop;
55     case 2:
56       return SparseParallelizationStrategy::kAnyStorageOuterLoop;
57     case 3:
58       return SparseParallelizationStrategy::kDenseAnyLoop;
59     case 4:
60       return SparseParallelizationStrategy::kAnyStorageAnyLoop;
61     }
62   }
63 
64   /// Returns vectorization strategy given on command line.
65   SparseVectorizationStrategy vectorOption() {
66     switch (vectorization) {
67     default:
68       return SparseVectorizationStrategy::kNone;
69     case 1:
70       return SparseVectorizationStrategy::kDenseInnerLoop;
71     case 2:
72       return SparseVectorizationStrategy::kAnyStorageInnerLoop;
73     }
74   }
75 
76   void runOnOperation() override {
77     auto *ctx = &getContext();
78     RewritePatternSet patterns(ctx);
79     // Translate strategy flags to strategy options.
80     SparsificationOptions options(parallelOption(), vectorOption(),
81                                   vectorLength);
82     // Apply rewriting.
83     populateSparsificationPatterns(patterns, options);
84     vector::populateVectorToVectorCanonicalizationPatterns(patterns);
85     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
86   }
87 };
88 
89 class SparseTensorTypeConverter : public TypeConverter {
90 public:
91   SparseTensorTypeConverter() {
92     addConversion([](Type type) { return type; });
93     addConversion(convertSparseTensorTypes);
94   }
95   // Maps each sparse tensor type to an opaque pointer.
96   static Optional<Type> convertSparseTensorTypes(Type type) {
97     if (getSparseTensorEncoding(type) != nullptr)
98       return LLVM::LLVMPointerType::get(IntegerType::get(type.getContext(), 8));
99     return llvm::None;
100   }
101 };
102 
103 struct SparseTensorConversionPass
104     : public SparseTensorConversionBase<SparseTensorConversionPass> {
105   void runOnOperation() override {
106     auto *ctx = &getContext();
107     RewritePatternSet patterns(ctx);
108     SparseTensorTypeConverter converter;
109     ConversionTarget target(*ctx);
110     target.addIllegalOp<NewOp, ToPointersOp, ToIndicesOp, ToValuesOp>();
111     target.addDynamicallyLegalOp<FuncOp>(
112         [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
113     target.addDynamicallyLegalOp<CallOp>([&](CallOp op) {
114       return converter.isSignatureLegal(op.getCalleeType());
115     });
116     target.addDynamicallyLegalOp<ReturnOp>(
117         [&](ReturnOp op) { return converter.isLegal(op.getOperandTypes()); });
118     target.addLegalOp<ConstantOp>();
119     target.addLegalOp<tensor::CastOp>();
120     populateFuncOpTypeConversionPattern(patterns, converter);
121     populateCallOpTypeConversionPattern(patterns, converter);
122     populateSparseTensorConversionPatterns(converter, patterns);
123     if (failed(applyPartialConversion(getOperation(), target,
124                                       std::move(patterns))))
125       signalPassFailure();
126   }
127 };
128 
129 } // end anonymous namespace
130 
131 std::unique_ptr<Pass> mlir::createSparsificationPass() {
132   return std::make_unique<SparsificationPass>();
133 }
134 
135 std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() {
136   return std::make_unique<SparseTensorConversionPass>();
137 }
138