1 //===- SparseTensorPasses.cpp - Pass for autogen sparse tensor code -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
10 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
11 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
12 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
13 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
14 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
15 
16 using namespace mlir;
17 using namespace mlir::sparse_tensor;
18 
19 namespace {
20 
21 //===----------------------------------------------------------------------===//
22 // Passes declaration.
23 //===----------------------------------------------------------------------===//
24 
25 #define GEN_PASS_CLASSES
26 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
27 
28 //===----------------------------------------------------------------------===//
29 // Passes implementation.
30 //===----------------------------------------------------------------------===//
31 
32 struct SparsificationPass : public SparsificationBase<SparsificationPass> {
33 
34   SparsificationPass() = default;
35   SparsificationPass(const SparsificationPass &pass)
36       : SparsificationBase<SparsificationPass>() {}
37 
38   /// Returns parallelization strategy given on command line.
39   SparseParallelizationStrategy parallelOption() {
40     switch (parallelization) {
41     default:
42       return SparseParallelizationStrategy::kNone;
43     case 1:
44       return SparseParallelizationStrategy::kDenseOuterLoop;
45     case 2:
46       return SparseParallelizationStrategy::kAnyStorageOuterLoop;
47     case 3:
48       return SparseParallelizationStrategy::kDenseAnyLoop;
49     case 4:
50       return SparseParallelizationStrategy::kAnyStorageAnyLoop;
51     }
52   }
53 
54   /// Returns vectorization strategy given on command line.
55   SparseVectorizationStrategy vectorOption() {
56     switch (vectorization) {
57     default:
58       return SparseVectorizationStrategy::kNone;
59     case 1:
60       return SparseVectorizationStrategy::kDenseInnerLoop;
61     case 2:
62       return SparseVectorizationStrategy::kAnyStorageInnerLoop;
63     }
64   }
65 
66   void runOnOperation() override {
67     auto *ctx = &getContext();
68     RewritePatternSet patterns(ctx);
69     // Translate strategy flags to strategy options.
70     SparsificationOptions options(parallelOption(), vectorOption(),
71                                   vectorLength, enableSIMDIndex32);
72     // Apply rewriting.
73     populateSparsificationPatterns(patterns, options);
74     vector::populateVectorToVectorCanonicalizationPatterns(patterns);
75     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
76   }
77 };
78 
79 class SparseTensorTypeConverter : public TypeConverter {
80 public:
81   SparseTensorTypeConverter() {
82     addConversion([](Type type) { return type; });
83     addConversion(convertSparseTensorTypes);
84   }
85   // Maps each sparse tensor type to an opaque pointer.
86   static Optional<Type> convertSparseTensorTypes(Type type) {
87     if (getSparseTensorEncoding(type) != nullptr)
88       return LLVM::LLVMPointerType::get(IntegerType::get(type.getContext(), 8));
89     return llvm::None;
90   }
91 };
92 
93 struct SparseTensorConversionPass
94     : public SparseTensorConversionBase<SparseTensorConversionPass> {
95   void runOnOperation() override {
96     auto *ctx = &getContext();
97     RewritePatternSet patterns(ctx);
98     SparseTensorTypeConverter converter;
99     ConversionTarget target(*ctx);
100     // Everything in the sparse dialect must go!
101     target.addIllegalDialect<SparseTensorDialect>();
102     // All dynamic rules below accept new function, call, return, and tensor
103     // dim and cast operations as legal output of the rewriting provided that
104     // all sparse tensor types have been fully rewritten.
105     target.addDynamicallyLegalOp<FuncOp>(
106         [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
107     target.addDynamicallyLegalOp<CallOp>([&](CallOp op) {
108       return converter.isSignatureLegal(op.getCalleeType());
109     });
110     target.addDynamicallyLegalOp<ReturnOp>(
111         [&](ReturnOp op) { return converter.isLegal(op.getOperandTypes()); });
112     target.addDynamicallyLegalOp<tensor::DimOp>([&](tensor::DimOp op) {
113       return converter.isLegal(op.getOperandTypes());
114     });
115     target.addDynamicallyLegalOp<tensor::CastOp>([&](tensor::CastOp op) {
116       return converter.isLegal(op.getOperand().getType());
117     });
118     // The following operations and dialects may be introduced by the
119     // rewriting rules, and are therefore marked as legal.
120     target.addLegalOp<arith::CmpFOp, arith::CmpIOp, arith::ConstantOp,
121                       arith::IndexCastOp, linalg::FillOp, linalg::YieldOp,
122                       tensor::ExtractOp>();
123     target.addLegalDialect<LLVM::LLVMDialect, memref::MemRefDialect,
124                            scf::SCFDialect>();
125     // Populate with rules and apply rewriting rules.
126     populateFuncOpTypeConversionPattern(patterns, converter);
127     populateCallOpTypeConversionPattern(patterns, converter);
128     populateSparseTensorConversionPatterns(converter, patterns);
129     if (failed(applyPartialConversion(getOperation(), target,
130                                       std::move(patterns))))
131       signalPassFailure();
132   }
133 };
134 
135 } // end anonymous namespace
136 
/// Creates an instance of the sparsification pass; strategy flags take their
/// defaults and may be overridden via the pass's command-line options.
std::unique_ptr<Pass> mlir::createSparsificationPass() {
  return std::make_unique<SparsificationPass>();
}
140 
/// Creates an instance of the sparse tensor type conversion pass.
std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() {
  return std::make_unique<SparseTensorConversionPass>();
}
144