1 //===- SparsificationPass.cpp - Pass for autogen spares tensor code -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 10 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 11 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 12 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" 13 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h" 14 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 15 16 using namespace mlir; 17 using namespace mlir::sparse_tensor; 18 19 namespace { 20 21 //===----------------------------------------------------------------------===// 22 // Passes declaration. 23 //===----------------------------------------------------------------------===// 24 25 #define GEN_PASS_CLASSES 26 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc" 27 28 //===----------------------------------------------------------------------===// 29 // Passes implementation. 30 //===----------------------------------------------------------------------===// 31 32 struct SparsificationPass : public SparsificationBase<SparsificationPass> { 33 34 SparsificationPass() = default; 35 SparsificationPass(const SparsificationPass &pass) {} 36 37 Option<int32_t> parallelization{ 38 *this, "parallelization-strategy", 39 llvm::cl::desc("Set the parallelization strategy"), llvm::cl::init(0)}; 40 41 Option<int32_t> vectorization{ 42 *this, "vectorization-strategy", 43 llvm::cl::desc("Set the vectorization strategy"), llvm::cl::init(0)}; 44 45 Option<int32_t> vectorLength{ 46 *this, "vl", llvm::cl::desc("Set the vector length"), llvm::cl::init(1)}; 47 48 Option<bool> fastOutput{*this, "fast-output", 49 llvm::cl::desc("Allows fast output buffers"), 50 llvm::cl::init(false)}; 51 52 /// Returns parallelization strategy given on command line. 53 SparseParallelizationStrategy parallelOption() { 54 switch (parallelization) { 55 default: 56 return SparseParallelizationStrategy::kNone; 57 case 1: 58 return SparseParallelizationStrategy::kDenseOuterLoop; 59 case 2: 60 return SparseParallelizationStrategy::kAnyStorageOuterLoop; 61 case 3: 62 return SparseParallelizationStrategy::kDenseAnyLoop; 63 case 4: 64 return SparseParallelizationStrategy::kAnyStorageAnyLoop; 65 } 66 } 67 68 /// Returns vectorization strategy given on command line. 69 SparseVectorizationStrategy vectorOption() { 70 switch (vectorization) { 71 default: 72 return SparseVectorizationStrategy::kNone; 73 case 1: 74 return SparseVectorizationStrategy::kDenseInnerLoop; 75 case 2: 76 return SparseVectorizationStrategy::kAnyStorageInnerLoop; 77 } 78 } 79 80 void runOnOperation() override { 81 auto *ctx = &getContext(); 82 RewritePatternSet patterns(ctx); 83 // Translate strategy flags to strategy options. 84 SparsificationOptions options(parallelOption(), vectorOption(), 85 vectorLength, fastOutput); 86 // Apply rewriting. 87 populateSparsificationPatterns(patterns, options); 88 vector::populateVectorToVectorCanonicalizationPatterns(patterns); 89 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 90 } 91 }; 92 93 class SparseTensorTypeConverter : public TypeConverter { 94 public: 95 SparseTensorTypeConverter() { 96 addConversion([](Type type) { return type; }); 97 addConversion(convertSparseTensorTypes); 98 } 99 // Maps each sparse tensor type to an opaque pointer. 100 static Optional<Type> convertSparseTensorTypes(Type type) { 101 if (getSparseTensorEncoding(type) != nullptr) 102 return LLVM::LLVMPointerType::get(IntegerType::get(type.getContext(), 8)); 103 return llvm::None; 104 } 105 }; 106 107 struct SparseTensorConversionPass 108 : public SparseTensorConversionBase<SparseTensorConversionPass> { 109 void runOnOperation() override { 110 auto *ctx = &getContext(); 111 RewritePatternSet patterns(ctx); 112 SparseTensorTypeConverter converter; 113 ConversionTarget target(*ctx); 114 target.addIllegalOp<NewOp, ToPointersOp, ToIndicesOp, ToValuesOp>(); 115 target.addDynamicallyLegalOp<FuncOp>( 116 [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); 117 target.addDynamicallyLegalOp<CallOp>([&](CallOp op) { 118 return converter.isSignatureLegal(op.getCalleeType()); 119 }); 120 target.addDynamicallyLegalOp<ReturnOp>( 121 [&](ReturnOp op) { return converter.isLegal(op.getOperandTypes()); }); 122 target.addLegalOp<ConstantOp>(); 123 populateFuncOpTypeConversionPattern(patterns, converter); 124 populateCallOpTypeConversionPattern(patterns, converter); 125 populateSparseTensorConversionPatterns(converter, patterns); 126 if (failed(applyPartialConversion(getOperation(), target, 127 std::move(patterns)))) 128 signalPassFailure(); 129 } 130 }; 131 132 } // end anonymous namespace 133 134 std::unique_ptr<Pass> mlir::createSparsificationPass() { 135 return std::make_unique<SparsificationPass>(); 136 } 137 138 std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() { 139 return std::make_unique<SparseTensorConversionPass>(); 140 } 141