1 //===- SparsificationPass.cpp - Pass for autogen spares tensor code -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 10 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 11 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 12 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" 13 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 14 15 using namespace mlir; 16 17 namespace { 18 19 //===----------------------------------------------------------------------===// 20 // Passes declaration. 21 //===----------------------------------------------------------------------===// 22 23 #define GEN_PASS_CLASSES 24 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc" 25 26 //===----------------------------------------------------------------------===// 27 // Passes implementation. 28 //===----------------------------------------------------------------------===// 29 30 struct SparsificationPass : public SparsificationBase<SparsificationPass> { 31 32 SparsificationPass() = default; 33 SparsificationPass(const SparsificationPass &pass) {} 34 35 Option<int32_t> parallelization{ 36 *this, "parallelization-strategy", 37 llvm::cl::desc("Set the parallelization strategy"), llvm::cl::init(0)}; 38 39 Option<int32_t> vectorization{ 40 *this, "vectorization-strategy", 41 llvm::cl::desc("Set the vectorization strategy"), llvm::cl::init(0)}; 42 43 Option<int32_t> vectorLength{ 44 *this, "vl", llvm::cl::desc("Set the vector length"), llvm::cl::init(1)}; 45 46 Option<int32_t> ptrType{*this, "ptr-type", 47 llvm::cl::desc("Set the pointer type"), 48 llvm::cl::init(0)}; 49 50 Option<int32_t> indType{*this, "ind-type", 51 llvm::cl::desc("Set the index type"), 52 llvm::cl::init(0)}; 53 54 Option<bool> fastOutput{*this, "fast-output", 55 llvm::cl::desc("Allows fast output buffers"), 56 llvm::cl::init(false)}; 57 58 /// Returns parallelization strategy given on command line. 59 SparseParallelizationStrategy parallelOption() { 60 switch (parallelization) { 61 default: 62 return SparseParallelizationStrategy::kNone; 63 case 1: 64 return SparseParallelizationStrategy::kDenseOuterLoop; 65 case 2: 66 return SparseParallelizationStrategy::kAnyStorageOuterLoop; 67 case 3: 68 return SparseParallelizationStrategy::kDenseAnyLoop; 69 case 4: 70 return SparseParallelizationStrategy::kAnyStorageAnyLoop; 71 } 72 } 73 74 /// Returns vectorization strategy given on command line. 75 SparseVectorizationStrategy vectorOption() { 76 switch (vectorization) { 77 default: 78 return SparseVectorizationStrategy::kNone; 79 case 1: 80 return SparseVectorizationStrategy::kDenseInnerLoop; 81 case 2: 82 return SparseVectorizationStrategy::kAnyStorageInnerLoop; 83 } 84 } 85 86 /// Returns the requested integer type. 87 SparseIntType typeOption(int32_t option) { 88 switch (option) { 89 default: 90 return SparseIntType::kNative; 91 case 1: 92 return SparseIntType::kI64; 93 case 2: 94 return SparseIntType::kI32; 95 case 3: 96 return SparseIntType::kI16; 97 case 4: 98 return SparseIntType::kI8; 99 } 100 } 101 102 void runOnOperation() override { 103 auto *ctx = &getContext(); 104 RewritePatternSet patterns(ctx); 105 // Translate strategy flags to strategy options. 106 SparsificationOptions options(parallelOption(), vectorOption(), 107 vectorLength, typeOption(ptrType), 108 typeOption(indType), fastOutput); 109 // Apply rewriting. 110 populateSparsificationPatterns(patterns, options); 111 vector::populateVectorToVectorCanonicalizationPatterns(patterns); 112 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 113 } 114 }; 115 116 struct SparseTensorConversionPass 117 : public SparseTensorConversionBase<SparseTensorConversionPass> { 118 void runOnOperation() override { 119 auto *ctx = &getContext(); 120 RewritePatternSet conversionPatterns(ctx); 121 ConversionTarget target(*ctx); 122 target 123 .addIllegalOp<sparse_tensor::FromPointerOp, sparse_tensor::ToPointersOp, 124 sparse_tensor::ToIndicesOp, sparse_tensor::ToValuesOp>(); 125 target.addLegalOp<CallOp>(); 126 populateSparseTensorConversionPatterns(conversionPatterns); 127 if (failed(applyPartialConversion(getOperation(), target, 128 std::move(conversionPatterns)))) 129 signalPassFailure(); 130 } 131 }; 132 133 } // end anonymous namespace 134 135 std::unique_ptr<Pass> mlir::createSparsificationPass() { 136 return std::make_unique<SparsificationPass>(); 137 } 138 139 std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() { 140 return std::make_unique<SparseTensorConversionPass>(); 141 } 142