1 //===- SparseTensorPasses.cpp - Pass for autogen sparse tensor code -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 10 #include "mlir/Dialect/Func/IR/FuncOps.h" 11 #include "mlir/Dialect/Func/Transforms/FuncConversions.h" 12 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 13 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 14 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 15 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" 16 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 17 18 using namespace mlir; 19 using namespace mlir::sparse_tensor; 20 21 namespace { 22 23 //===----------------------------------------------------------------------===// 24 // Passes declaration. 25 //===----------------------------------------------------------------------===// 26 27 #define GEN_PASS_CLASSES 28 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc" 29 30 //===----------------------------------------------------------------------===// 31 // Passes implementation. 32 //===----------------------------------------------------------------------===// 33 34 struct SparsificationPass : public SparsificationBase<SparsificationPass> { 35 36 SparsificationPass() = default; 37 SparsificationPass(const SparsificationPass &pass) = default; 38 SparsificationPass(const SparsificationOptions &options) { 39 parallelization = static_cast<int32_t>(options.parallelizationStrategy); 40 vectorization = static_cast<int32_t>(options.vectorizationStrategy); 41 vectorLength = options.vectorLength; 42 enableSIMDIndex32 = options.enableSIMDIndex32; 43 enableVLAVectorization = options.enableVLAVectorization; 44 } 45 46 void runOnOperation() override { 47 auto *ctx = &getContext(); 48 RewritePatternSet patterns(ctx); 49 // Translate strategy flags to strategy options. 50 SparsificationOptions options( 51 sparseParallelizationStrategy(parallelization), 52 sparseVectorizationStrategy(vectorization), vectorLength, 53 enableSIMDIndex32, enableVLAVectorization); 54 // Apply rewriting. 55 populateSparsificationPatterns(patterns, options); 56 vector::populateVectorToVectorCanonicalizationPatterns(patterns); 57 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 58 } 59 }; 60 61 class SparseTensorTypeConverter : public TypeConverter { 62 public: 63 SparseTensorTypeConverter() { 64 addConversion([](Type type) { return type; }); 65 addConversion(convertSparseTensorTypes); 66 } 67 // Maps each sparse tensor type to an opaque pointer. 68 static Optional<Type> convertSparseTensorTypes(Type type) { 69 if (getSparseTensorEncoding(type) != nullptr) 70 return LLVM::LLVMPointerType::get(IntegerType::get(type.getContext(), 8)); 71 return llvm::None; 72 } 73 }; 74 75 struct SparseTensorConversionPass 76 : public SparseTensorConversionBase<SparseTensorConversionPass> { 77 78 SparseTensorConversionPass() = default; 79 SparseTensorConversionPass(const SparseTensorConversionPass &pass) = default; 80 SparseTensorConversionPass(const SparseTensorConversionOptions &options) { 81 sparseToSparse = static_cast<int32_t>(options.sparseToSparseStrategy); 82 } 83 84 void runOnOperation() override { 85 auto *ctx = &getContext(); 86 RewritePatternSet patterns(ctx); 87 SparseTensorTypeConverter converter; 88 ConversionTarget target(*ctx); 89 // Everything in the sparse dialect must go! 90 target.addIllegalDialect<SparseTensorDialect>(); 91 // All dynamic rules below accept new function, call, return, and tensor 92 // dim and cast operations as legal output of the rewriting provided that 93 // all sparse tensor types have been fully rewritten. 94 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) { 95 return converter.isSignatureLegal(op.getFunctionType()); 96 }); 97 target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) { 98 return converter.isSignatureLegal(op.getCalleeType()); 99 }); 100 target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) { 101 return converter.isLegal(op.getOperandTypes()); 102 }); 103 target.addDynamicallyLegalOp<tensor::DimOp>([&](tensor::DimOp op) { 104 return converter.isLegal(op.getOperandTypes()); 105 }); 106 target.addDynamicallyLegalOp<tensor::CastOp>([&](tensor::CastOp op) { 107 return converter.isLegal(op.getOperand().getType()); 108 }); 109 // The following operations and dialects may be introduced by the 110 // rewriting rules, and are therefore marked as legal. 111 target.addLegalOp<arith::CmpFOp, arith::CmpIOp, arith::ConstantOp, 112 arith::IndexCastOp, linalg::FillOp, linalg::YieldOp, 113 tensor::ExtractOp>(); 114 target 115 .addLegalDialect<bufferization::BufferizationDialect, LLVM::LLVMDialect, 116 memref::MemRefDialect, scf::SCFDialect>(); 117 // Translate strategy flags to strategy options. 118 SparseTensorConversionOptions options( 119 sparseToSparseConversionStrategy(sparseToSparse)); 120 // Populate with rules and apply rewriting rules. 121 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns, 122 converter); 123 populateCallOpTypeConversionPattern(patterns, converter); 124 populateSparseTensorConversionPatterns(converter, patterns, options); 125 if (failed(applyPartialConversion(getOperation(), target, 126 std::move(patterns)))) 127 signalPassFailure(); 128 } 129 }; 130 131 } // namespace 132 133 SparseParallelizationStrategy 134 mlir::sparseParallelizationStrategy(int32_t flag) { 135 switch (flag) { 136 default: 137 return SparseParallelizationStrategy::kNone; 138 case 1: 139 return SparseParallelizationStrategy::kDenseOuterLoop; 140 case 2: 141 return SparseParallelizationStrategy::kAnyStorageOuterLoop; 142 case 3: 143 return SparseParallelizationStrategy::kDenseAnyLoop; 144 case 4: 145 return SparseParallelizationStrategy::kAnyStorageAnyLoop; 146 } 147 } 148 149 SparseVectorizationStrategy mlir::sparseVectorizationStrategy(int32_t flag) { 150 switch (flag) { 151 default: 152 return SparseVectorizationStrategy::kNone; 153 case 1: 154 return SparseVectorizationStrategy::kDenseInnerLoop; 155 case 2: 156 return SparseVectorizationStrategy::kAnyStorageInnerLoop; 157 } 158 } 159 160 SparseToSparseConversionStrategy 161 mlir::sparseToSparseConversionStrategy(int32_t flag) { 162 switch (flag) { 163 default: 164 return SparseToSparseConversionStrategy::kAuto; 165 case 1: 166 return SparseToSparseConversionStrategy::kViaCOO; 167 case 2: 168 return SparseToSparseConversionStrategy::kDirect; 169 } 170 } 171 172 std::unique_ptr<Pass> mlir::createSparsificationPass() { 173 return std::make_unique<SparsificationPass>(); 174 } 175 176 std::unique_ptr<Pass> 177 mlir::createSparsificationPass(const SparsificationOptions &options) { 178 return std::make_unique<SparsificationPass>(options); 179 } 180 181 std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() { 182 return std::make_unique<SparseTensorConversionPass>(); 183 } 184 185 std::unique_ptr<Pass> mlir::createSparseTensorConversionPass( 186 const SparseTensorConversionOptions &options) { 187 return std::make_unique<SparseTensorConversionPass>(options); 188 } 189