//===- SparsificationPass.cpp - Pass for autogen sparse tensor code -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
10 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
11 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
12 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
13 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
14 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
15 
16 using namespace mlir;
17 using namespace mlir::sparse_tensor;
18 
19 namespace {
20 
21 //===----------------------------------------------------------------------===//
22 // Passes declaration.
23 //===----------------------------------------------------------------------===//
24 
25 #define GEN_PASS_CLASSES
26 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
27 
28 //===----------------------------------------------------------------------===//
29 // Passes implementation.
30 //===----------------------------------------------------------------------===//
31 
32 struct SparsificationPass : public SparsificationBase<SparsificationPass> {
33 
34   SparsificationPass() = default;
35   SparsificationPass(const SparsificationPass &pass)
36       : SparsificationBase<SparsificationPass>() {}
37 
38   /// Returns parallelization strategy given on command line.
39   SparseParallelizationStrategy parallelOption() {
40     switch (parallelization) {
41     default:
42       return SparseParallelizationStrategy::kNone;
43     case 1:
44       return SparseParallelizationStrategy::kDenseOuterLoop;
45     case 2:
46       return SparseParallelizationStrategy::kAnyStorageOuterLoop;
47     case 3:
48       return SparseParallelizationStrategy::kDenseAnyLoop;
49     case 4:
50       return SparseParallelizationStrategy::kAnyStorageAnyLoop;
51     }
52   }
53 
54   /// Returns vectorization strategy given on command line.
55   SparseVectorizationStrategy vectorOption() {
56     switch (vectorization) {
57     default:
58       return SparseVectorizationStrategy::kNone;
59     case 1:
60       return SparseVectorizationStrategy::kDenseInnerLoop;
61     case 2:
62       return SparseVectorizationStrategy::kAnyStorageInnerLoop;
63     }
64   }
65 
66   void runOnOperation() override {
67     auto *ctx = &getContext();
68     RewritePatternSet patterns(ctx);
69     // Translate strategy flags to strategy options.
70     SparsificationOptions options(parallelOption(), vectorOption(),
71                                   vectorLength, enableSIMDIndex32);
72     // Apply rewriting.
73     populateSparsificationPatterns(patterns, options);
74     vector::populateVectorToVectorCanonicalizationPatterns(patterns);
75     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
76   }
77 };
78 
79 class SparseTensorTypeConverter : public TypeConverter {
80 public:
81   SparseTensorTypeConverter() {
82     addConversion([](Type type) { return type; });
83     addConversion(convertSparseTensorTypes);
84   }
85   // Maps each sparse tensor type to an opaque pointer.
86   static Optional<Type> convertSparseTensorTypes(Type type) {
87     if (getSparseTensorEncoding(type) != nullptr)
88       return LLVM::LLVMPointerType::get(IntegerType::get(type.getContext(), 8));
89     return llvm::None;
90   }
91 };
92 
93 struct SparseTensorConversionPass
94     : public SparseTensorConversionBase<SparseTensorConversionPass> {
95   void runOnOperation() override {
96     auto *ctx = &getContext();
97     RewritePatternSet patterns(ctx);
98     SparseTensorTypeConverter converter;
99     ConversionTarget target(*ctx);
100     target.addIllegalOp<NewOp, ConvertOp, ToPointersOp, ToIndicesOp, ToValuesOp,
101                         ToTensorOp>();
102     target.addDynamicallyLegalOp<FuncOp>(
103         [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
104     target.addDynamicallyLegalOp<CallOp>([&](CallOp op) {
105       return converter.isSignatureLegal(op.getCalleeType());
106     });
107     target.addDynamicallyLegalOp<ReturnOp>(
108         [&](ReturnOp op) { return converter.isLegal(op.getOperandTypes()); });
109     target.addLegalOp<ConstantOp, tensor::CastOp, memref::BufferCastOp,
110                       memref::CastOp>();
111     populateFuncOpTypeConversionPattern(patterns, converter);
112     populateCallOpTypeConversionPattern(patterns, converter);
113     populateSparseTensorConversionPatterns(converter, patterns);
114     if (failed(applyPartialConversion(getOperation(), target,
115                                       std::move(patterns))))
116       signalPassFailure();
117   }
118 };
119 
120 } // end anonymous namespace
121 
122 std::unique_ptr<Pass> mlir::createSparsificationPass() {
123   return std::make_unique<SparsificationPass>();
124 }
125 
126 std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() {
127   return std::make_unique<SparseTensorConversionPass>();
128 }
129