1 //===- SparseTensorPasses.cpp - Pass for autogen sparse tensor code -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Affine/IR/AffineOps.h"
10 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
11 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
12 #include "mlir/Dialect/Func/IR/FuncOps.h"
13 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
16 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
17 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
18 #include "mlir/Dialect/Tensor/IR/Tensor.h"
19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20 
21 using namespace mlir;
22 using namespace mlir::sparse_tensor;
23 
24 namespace {
25 
26 //===----------------------------------------------------------------------===//
27 // Passes declaration.
28 //===----------------------------------------------------------------------===//
29 
30 #define GEN_PASS_CLASSES
31 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
32 
33 //===----------------------------------------------------------------------===//
34 // Passes implementation.
35 //===----------------------------------------------------------------------===//
36 
37 struct SparsificationPass : public SparsificationBase<SparsificationPass> {
38 
39   SparsificationPass() = default;
40   SparsificationPass(const SparsificationPass &pass) = default;
41   SparsificationPass(const SparsificationOptions &options) {
42     parallelization = options.parallelizationStrategy;
43     vectorization = options.vectorizationStrategy;
44     vectorLength = options.vectorLength;
45     enableSIMDIndex32 = options.enableSIMDIndex32;
46     enableVLAVectorization = options.enableVLAVectorization;
47   }
48 
49   void runOnOperation() override {
50     auto *ctx = &getContext();
51     RewritePatternSet patterns(ctx);
52     // Translate strategy flags to strategy options.
53     SparsificationOptions options(parallelization, vectorization, vectorLength,
54                                   enableSIMDIndex32, enableVLAVectorization);
55     // Apply rewriting.
56     populateSparsificationPatterns(patterns, options);
57     vector::populateVectorToVectorCanonicalizationPatterns(patterns);
58     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
59   }
60 };
61 
62 class SparseTensorTypeConverter : public TypeConverter {
63 public:
64   SparseTensorTypeConverter() {
65     addConversion([](Type type) { return type; });
66     addConversion(convertSparseTensorTypes);
67   }
68   // Maps each sparse tensor type to an opaque pointer.
69   static Optional<Type> convertSparseTensorTypes(Type type) {
70     if (getSparseTensorEncoding(type) != nullptr)
71       return LLVM::LLVMPointerType::get(IntegerType::get(type.getContext(), 8));
72     return llvm::None;
73   }
74 };
75 
76 struct SparseTensorConversionPass
77     : public SparseTensorConversionBase<SparseTensorConversionPass> {
78 
79   SparseTensorConversionPass() = default;
80   SparseTensorConversionPass(const SparseTensorConversionPass &pass) = default;
81   SparseTensorConversionPass(const SparseTensorConversionOptions &options) {
82     sparseToSparse = static_cast<int32_t>(options.sparseToSparseStrategy);
83   }
84 
85   void runOnOperation() override {
86     auto *ctx = &getContext();
87     RewritePatternSet patterns(ctx);
88     SparseTensorTypeConverter converter;
89     ConversionTarget target(*ctx);
90     // Everything in the sparse dialect must go!
91     target.addIllegalDialect<SparseTensorDialect>();
92     // All dynamic rules below accept new function, call, return, and tensor
93     // dim and cast operations as legal output of the rewriting provided that
94     // all sparse tensor types have been fully rewritten.
95     target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
96       return converter.isSignatureLegal(op.getFunctionType());
97     });
98     target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
99       return converter.isSignatureLegal(op.getCalleeType());
100     });
101     target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
102       return converter.isLegal(op.getOperandTypes());
103     });
104     target.addDynamicallyLegalOp<tensor::DimOp>([&](tensor::DimOp op) {
105       return converter.isLegal(op.getOperandTypes());
106     });
107     target.addDynamicallyLegalOp<tensor::CastOp>([&](tensor::CastOp op) {
108       return converter.isLegal(op.getOperand().getType());
109     });
110     // The following operations and dialects may be introduced by the
111     // rewriting rules, and are therefore marked as legal.
112     target.addLegalOp<arith::CmpFOp, arith::CmpIOp, arith::ConstantOp,
113                       arith::IndexCastOp, linalg::FillOp, linalg::YieldOp,
114                       tensor::ExtractOp>();
115     target
116         .addLegalDialect<bufferization::BufferizationDialect, LLVM::LLVMDialect,
117                          memref::MemRefDialect, scf::SCFDialect>();
118     // Translate strategy flags to strategy options.
119     SparseTensorConversionOptions options(
120         sparseToSparseConversionStrategy(sparseToSparse));
121     // Populate with rules and apply rewriting rules.
122     populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
123                                                                    converter);
124     populateCallOpTypeConversionPattern(patterns, converter);
125     populateSparseTensorConversionPatterns(converter, patterns, options);
126     if (failed(applyPartialConversion(getOperation(), target,
127                                       std::move(patterns))))
128       signalPassFailure();
129   }
130 };
131 
132 } // namespace
133 
134 SparseToSparseConversionStrategy
135 mlir::sparseToSparseConversionStrategy(int32_t flag) {
136   switch (flag) {
137   default:
138     return SparseToSparseConversionStrategy::kAuto;
139   case 1:
140     return SparseToSparseConversionStrategy::kViaCOO;
141   case 2:
142     return SparseToSparseConversionStrategy::kDirect;
143   }
144 }
145 
146 std::unique_ptr<Pass> mlir::createSparsificationPass() {
147   return std::make_unique<SparsificationPass>();
148 }
149 
150 std::unique_ptr<Pass>
151 mlir::createSparsificationPass(const SparsificationOptions &options) {
152   return std::make_unique<SparsificationPass>(options);
153 }
154 
155 std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() {
156   return std::make_unique<SparseTensorConversionPass>();
157 }
158 
159 std::unique_ptr<Pass> mlir::createSparseTensorConversionPass(
160     const SparseTensorConversionOptions &options) {
161   return std::make_unique<SparseTensorConversionPass>(options);
162 }
163