1 //===- SparseTensorPasses.cpp - Pass for autogen sparse tensor code -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
10 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
11 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
12 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
13 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
14 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
15 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
16 
17 using namespace mlir;
18 using namespace mlir::sparse_tensor;
19 
20 namespace {
21 
22 //===----------------------------------------------------------------------===//
23 // Passes declaration.
24 //===----------------------------------------------------------------------===//
25 
26 #define GEN_PASS_CLASSES
27 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
28 
29 //===----------------------------------------------------------------------===//
30 // Passes implementation.
31 //===----------------------------------------------------------------------===//
32 
33 struct SparsificationPass : public SparsificationBase<SparsificationPass> {
34 
35   SparsificationPass() = default;
36   SparsificationPass(const SparsificationPass &pass) = default;
37   SparsificationPass(const SparsificationOptions &options) {
38     parallelization = static_cast<int32_t>(options.parallelizationStrategy);
39     vectorization = static_cast<int32_t>(options.vectorizationStrategy);
40     vectorLength = options.vectorLength;
41     enableSIMDIndex32 = options.enableSIMDIndex32;
42   }
43 
44   void runOnOperation() override {
45     auto *ctx = &getContext();
46     RewritePatternSet patterns(ctx);
47     // Translate strategy flags to strategy options.
48     SparsificationOptions options(
49         sparseParallelizationStrategy(parallelization),
50         sparseVectorizationStrategy(vectorization), vectorLength,
51         enableSIMDIndex32);
52     // Apply rewriting.
53     populateSparsificationPatterns(patterns, options);
54     vector::populateVectorToVectorCanonicalizationPatterns(patterns);
55     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
56   }
57 };
58 
59 class SparseTensorTypeConverter : public TypeConverter {
60 public:
61   SparseTensorTypeConverter() {
62     addConversion([](Type type) { return type; });
63     addConversion(convertSparseTensorTypes);
64   }
65   // Maps each sparse tensor type to an opaque pointer.
66   static Optional<Type> convertSparseTensorTypes(Type type) {
67     if (getSparseTensorEncoding(type) != nullptr)
68       return LLVM::LLVMPointerType::get(IntegerType::get(type.getContext(), 8));
69     return llvm::None;
70   }
71 };
72 
73 struct SparseTensorConversionPass
74     : public SparseTensorConversionBase<SparseTensorConversionPass> {
75   void runOnOperation() override {
76     auto *ctx = &getContext();
77     RewritePatternSet patterns(ctx);
78     SparseTensorTypeConverter converter;
79     ConversionTarget target(*ctx);
80     // Everything in the sparse dialect must go!
81     target.addIllegalDialect<SparseTensorDialect>();
82     // All dynamic rules below accept new function, call, return, and tensor
83     // dim and cast operations as legal output of the rewriting provided that
84     // all sparse tensor types have been fully rewritten.
85     target.addDynamicallyLegalOp<FuncOp>(
86         [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
87     target.addDynamicallyLegalOp<CallOp>([&](CallOp op) {
88       return converter.isSignatureLegal(op.getCalleeType());
89     });
90     target.addDynamicallyLegalOp<ReturnOp>(
91         [&](ReturnOp op) { return converter.isLegal(op.getOperandTypes()); });
92     target.addDynamicallyLegalOp<tensor::DimOp>([&](tensor::DimOp op) {
93       return converter.isLegal(op.getOperandTypes());
94     });
95     target.addDynamicallyLegalOp<tensor::CastOp>([&](tensor::CastOp op) {
96       return converter.isLegal(op.getOperand().getType());
97     });
98     // The following operations and dialects may be introduced by the
99     // rewriting rules, and are therefore marked as legal.
100     target.addLegalOp<arith::CmpFOp, arith::CmpIOp, arith::ConstantOp,
101                       arith::IndexCastOp, linalg::FillOp, linalg::YieldOp,
102                       tensor::ExtractOp>();
103     target
104         .addLegalDialect<bufferization::BufferizationDialect, LLVM::LLVMDialect,
105                          memref::MemRefDialect, scf::SCFDialect>();
106     // Populate with rules and apply rewriting rules.
107     populateFunctionOpInterfaceTypeConversionPattern<FuncOp>(patterns,
108                                                              converter);
109     populateCallOpTypeConversionPattern(patterns, converter);
110     populateSparseTensorConversionPatterns(converter, patterns);
111     if (failed(applyPartialConversion(getOperation(), target,
112                                       std::move(patterns))))
113       signalPassFailure();
114   }
115 };
116 
117 } // namespace
118 
119 SparseParallelizationStrategy
120 mlir::sparseParallelizationStrategy(int32_t flag) {
121   switch (flag) {
122   default:
123     return SparseParallelizationStrategy::kNone;
124   case 1:
125     return SparseParallelizationStrategy::kDenseOuterLoop;
126   case 2:
127     return SparseParallelizationStrategy::kAnyStorageOuterLoop;
128   case 3:
129     return SparseParallelizationStrategy::kDenseAnyLoop;
130   case 4:
131     return SparseParallelizationStrategy::kAnyStorageAnyLoop;
132   }
133 }
134 
135 SparseVectorizationStrategy mlir::sparseVectorizationStrategy(int32_t flag) {
136   switch (flag) {
137   default:
138     return SparseVectorizationStrategy::kNone;
139   case 1:
140     return SparseVectorizationStrategy::kDenseInnerLoop;
141   case 2:
142     return SparseVectorizationStrategy::kAnyStorageInnerLoop;
143   }
144 }
145 
146 std::unique_ptr<Pass> mlir::createSparsificationPass() {
147   return std::make_unique<SparsificationPass>();
148 }
149 
150 std::unique_ptr<Pass>
151 mlir::createSparsificationPass(const SparsificationOptions &options) {
152   return std::make_unique<SparsificationPass>(options);
153 }
154 
155 std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() {
156   return std::make_unique<SparseTensorConversionPass>();
157 }
158