1 //===- SparseTensorPasses.cpp - Pass for autogen sparse tensor code -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
10 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
11 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
12 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
13 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
14 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
15 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
16 
17 using namespace mlir;
18 using namespace mlir::sparse_tensor;
19 
20 namespace {
21 
22 //===----------------------------------------------------------------------===//
23 // Passes declaration.
24 //===----------------------------------------------------------------------===//
25 
26 #define GEN_PASS_CLASSES
27 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
28 
29 //===----------------------------------------------------------------------===//
30 // Passes implementation.
31 //===----------------------------------------------------------------------===//
32 
33 struct SparsificationPass : public SparsificationBase<SparsificationPass> {
34 
35   SparsificationPass() = default;
36   SparsificationPass(const SparsificationPass &pass)
37       : SparsificationBase<SparsificationPass>() {}
38 
39   /// Returns parallelization strategy given on command line.
40   SparseParallelizationStrategy parallelOption() {
41     switch (parallelization) {
42     default:
43       return SparseParallelizationStrategy::kNone;
44     case 1:
45       return SparseParallelizationStrategy::kDenseOuterLoop;
46     case 2:
47       return SparseParallelizationStrategy::kAnyStorageOuterLoop;
48     case 3:
49       return SparseParallelizationStrategy::kDenseAnyLoop;
50     case 4:
51       return SparseParallelizationStrategy::kAnyStorageAnyLoop;
52     }
53   }
54 
55   /// Returns vectorization strategy given on command line.
56   SparseVectorizationStrategy vectorOption() {
57     switch (vectorization) {
58     default:
59       return SparseVectorizationStrategy::kNone;
60     case 1:
61       return SparseVectorizationStrategy::kDenseInnerLoop;
62     case 2:
63       return SparseVectorizationStrategy::kAnyStorageInnerLoop;
64     }
65   }
66 
67   void runOnOperation() override {
68     auto *ctx = &getContext();
69     RewritePatternSet patterns(ctx);
70     // Translate strategy flags to strategy options.
71     SparsificationOptions options(parallelOption(), vectorOption(),
72                                   vectorLength, enableSIMDIndex32);
73     // Apply rewriting.
74     populateSparsificationPatterns(patterns, options);
75     vector::populateVectorToVectorCanonicalizationPatterns(patterns);
76     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
77   }
78 };
79 
80 class SparseTensorTypeConverter : public TypeConverter {
81 public:
82   SparseTensorTypeConverter() {
83     addConversion([](Type type) { return type; });
84     addConversion(convertSparseTensorTypes);
85   }
86   // Maps each sparse tensor type to an opaque pointer.
87   static Optional<Type> convertSparseTensorTypes(Type type) {
88     if (getSparseTensorEncoding(type) != nullptr)
89       return LLVM::LLVMPointerType::get(IntegerType::get(type.getContext(), 8));
90     return llvm::None;
91   }
92 };
93 
94 struct SparseTensorConversionPass
95     : public SparseTensorConversionBase<SparseTensorConversionPass> {
96   void runOnOperation() override {
97     auto *ctx = &getContext();
98     RewritePatternSet patterns(ctx);
99     SparseTensorTypeConverter converter;
100     ConversionTarget target(*ctx);
101     // Everything in the sparse dialect must go!
102     target.addIllegalDialect<SparseTensorDialect>();
103     // All dynamic rules below accept new function, call, return, and tensor
104     // dim and cast operations as legal output of the rewriting provided that
105     // all sparse tensor types have been fully rewritten.
106     target.addDynamicallyLegalOp<FuncOp>(
107         [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
108     target.addDynamicallyLegalOp<CallOp>([&](CallOp op) {
109       return converter.isSignatureLegal(op.getCalleeType());
110     });
111     target.addDynamicallyLegalOp<ReturnOp>(
112         [&](ReturnOp op) { return converter.isLegal(op.getOperandTypes()); });
113     target.addDynamicallyLegalOp<tensor::DimOp>([&](tensor::DimOp op) {
114       return converter.isLegal(op.getOperandTypes());
115     });
116     target.addDynamicallyLegalOp<tensor::CastOp>([&](tensor::CastOp op) {
117       return converter.isLegal(op.getOperand().getType());
118     });
119     // The following operations and dialects may be introduced by the
120     // rewriting rules, and are therefore marked as legal.
121     target.addLegalOp<arith::CmpFOp, arith::CmpIOp, arith::ConstantOp,
122                       arith::IndexCastOp, linalg::FillOp, linalg::YieldOp,
123                       tensor::ExtractOp>();
124     target
125         .addLegalDialect<bufferization::BufferizationDialect, LLVM::LLVMDialect,
126                          memref::MemRefDialect, scf::SCFDialect>();
127     // Populate with rules and apply rewriting rules.
128     populateFuncOpTypeConversionPattern(patterns, converter);
129     populateCallOpTypeConversionPattern(patterns, converter);
130     populateSparseTensorConversionPatterns(converter, patterns);
131     if (failed(applyPartialConversion(getOperation(), target,
132                                       std::move(patterns))))
133       signalPassFailure();
134   }
135 };
136 
137 } // namespace
138 
139 std::unique_ptr<Pass> mlir::createSparsificationPass() {
140   return std::make_unique<SparsificationPass>();
141 }
142 
143 std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() {
144   return std::make_unique<SparseTensorConversionPass>();
145 }
146