1 //===- SparsificationPass.cpp - Pass for autogen spares tensor code -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 10 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 11 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 12 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" 13 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h" 14 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 15 16 using namespace mlir; 17 using namespace mlir::sparse_tensor; 18 19 namespace { 20 21 //===----------------------------------------------------------------------===// 22 // Passes declaration. 23 //===----------------------------------------------------------------------===// 24 25 #define GEN_PASS_CLASSES 26 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc" 27 28 //===----------------------------------------------------------------------===// 29 // Passes implementation. 30 //===----------------------------------------------------------------------===// 31 32 struct SparsificationPass : public SparsificationBase<SparsificationPass> { 33 34 SparsificationPass() = default; 35 SparsificationPass(const SparsificationPass &pass) {} 36 37 Option<int32_t> parallelization{ 38 *this, "parallelization-strategy", 39 llvm::cl::desc("Set the parallelization strategy"), llvm::cl::init(0)}; 40 41 Option<int32_t> vectorization{ 42 *this, "vectorization-strategy", 43 llvm::cl::desc("Set the vectorization strategy"), llvm::cl::init(0)}; 44 45 Option<int32_t> vectorLength{ 46 *this, "vl", llvm::cl::desc("Set the vector length"), llvm::cl::init(1)}; 47 48 /// Returns parallelization strategy given on command line. 49 SparseParallelizationStrategy parallelOption() { 50 switch (parallelization) { 51 default: 52 return SparseParallelizationStrategy::kNone; 53 case 1: 54 return SparseParallelizationStrategy::kDenseOuterLoop; 55 case 2: 56 return SparseParallelizationStrategy::kAnyStorageOuterLoop; 57 case 3: 58 return SparseParallelizationStrategy::kDenseAnyLoop; 59 case 4: 60 return SparseParallelizationStrategy::kAnyStorageAnyLoop; 61 } 62 } 63 64 /// Returns vectorization strategy given on command line. 65 SparseVectorizationStrategy vectorOption() { 66 switch (vectorization) { 67 default: 68 return SparseVectorizationStrategy::kNone; 69 case 1: 70 return SparseVectorizationStrategy::kDenseInnerLoop; 71 case 2: 72 return SparseVectorizationStrategy::kAnyStorageInnerLoop; 73 } 74 } 75 76 void runOnOperation() override { 77 auto *ctx = &getContext(); 78 RewritePatternSet patterns(ctx); 79 // Translate strategy flags to strategy options. 80 SparsificationOptions options(parallelOption(), vectorOption(), 81 vectorLength); 82 // Apply rewriting. 83 populateSparsificationPatterns(patterns, options); 84 vector::populateVectorToVectorCanonicalizationPatterns(patterns); 85 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 86 } 87 }; 88 89 class SparseTensorTypeConverter : public TypeConverter { 90 public: 91 SparseTensorTypeConverter() { 92 addConversion([](Type type) { return type; }); 93 addConversion(convertSparseTensorTypes); 94 } 95 // Maps each sparse tensor type to an opaque pointer. 96 static Optional<Type> convertSparseTensorTypes(Type type) { 97 if (getSparseTensorEncoding(type) != nullptr) 98 return LLVM::LLVMPointerType::get(IntegerType::get(type.getContext(), 8)); 99 return llvm::None; 100 } 101 }; 102 103 struct SparseTensorConversionPass 104 : public SparseTensorConversionBase<SparseTensorConversionPass> { 105 void runOnOperation() override { 106 auto *ctx = &getContext(); 107 RewritePatternSet patterns(ctx); 108 SparseTensorTypeConverter converter; 109 ConversionTarget target(*ctx); 110 target.addIllegalOp<NewOp, ToPointersOp, ToIndicesOp, ToValuesOp>(); 111 target.addDynamicallyLegalOp<FuncOp>( 112 [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); 113 target.addDynamicallyLegalOp<CallOp>([&](CallOp op) { 114 return converter.isSignatureLegal(op.getCalleeType()); 115 }); 116 target.addDynamicallyLegalOp<ReturnOp>( 117 [&](ReturnOp op) { return converter.isLegal(op.getOperandTypes()); }); 118 target.addLegalOp<ConstantOp>(); 119 target.addLegalOp<tensor::CastOp>(); 120 populateFuncOpTypeConversionPattern(patterns, converter); 121 populateCallOpTypeConversionPattern(patterns, converter); 122 populateSparseTensorConversionPatterns(converter, patterns); 123 if (failed(applyPartialConversion(getOperation(), target, 124 std::move(patterns)))) 125 signalPassFailure(); 126 } 127 }; 128 129 } // end anonymous namespace 130 131 std::unique_ptr<Pass> mlir::createSparsificationPass() { 132 return std::make_unique<SparsificationPass>(); 133 } 134 135 std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() { 136 return std::make_unique<SparseTensorConversionPass>(); 137 } 138