//===- SparseTensorPasses.cpp - Pass for autogen sparse tensor code -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

//===----------------------------------------------------------------------===//
// Passes declaration.
//===----------------------------------------------------------------------===//

#define GEN_PASS_CLASSES
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"

//===----------------------------------------------------------------------===//
// Passes implementation.
//===----------------------------------------------------------------------===//

struct SparsificationPass : public SparsificationBase<SparsificationPass> {

  SparsificationPass() = default;
  SparsificationPass(const SparsificationPass &pass) = default;
  SparsificationPass(const SparsificationOptions &options) {
    parallelization = static_cast<int32_t>(options.parallelizationStrategy);
    vectorization = static_cast<int32_t>(options.vectorizationStrategy);
    vectorLength = options.vectorLength;
    enableSIMDIndex32 = options.enableSIMDIndex32;
  }

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    // Translate strategy flags to strategy options.
    SparsificationOptions options(
        sparseParallelizationStrategy(parallelization),
        sparseVectorizationStrategy(vectorization), vectorLength,
        enableSIMDIndex32);
    // Apply rewriting.
    populateSparsificationPatterns(patterns, options);
    vector::populateVectorToVectorCanonicalizationPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

class SparseTensorTypeConverter : public TypeConverter {
public:
  SparseTensorTypeConverter() {
    addConversion([](Type type) { return type; });
    addConversion(convertSparseTensorTypes);
  }
  // Maps each sparse tensor type to an opaque pointer.
  static Optional<Type> convertSparseTensorTypes(Type type) {
    if (getSparseTensorEncoding(type) != nullptr)
      return LLVM::LLVMPointerType::get(
          IntegerType::get(type.getContext(), 8));
    return llvm::None;
  }
};

struct SparseTensorConversionPass
    : public SparseTensorConversionBase<SparseTensorConversionPass> {

  SparseTensorConversionPass() = default;
  SparseTensorConversionPass(const SparseTensorConversionPass &pass) = default;
  SparseTensorConversionPass(const SparseTensorConversionOptions &options) {
    sparseToSparse = static_cast<int32_t>(options.sparseToSparseStrategy);
  }

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    SparseTensorTypeConverter converter;
    ConversionTarget target(*ctx);
    // Everything in the sparse dialect must go!
    target.addIllegalDialect<SparseTensorDialect>();
    // All dynamic rules below accept new function, call, return, and tensor
    // dim and cast operations as legal output of the rewriting provided that
    // all sparse tensor types have been fully rewritten.
    target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
      return converter.isSignatureLegal(op.getFunctionType());
    });
    target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
      return converter.isSignatureLegal(op.getCalleeType());
    });
    target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
      return converter.isLegal(op.getOperandTypes());
    });
    target.addDynamicallyLegalOp<tensor::DimOp>([&](tensor::DimOp op) {
      return converter.isLegal(op.getOperandTypes());
    });
    target.addDynamicallyLegalOp<tensor::CastOp>([&](tensor::CastOp op) {
      return converter.isLegal(op.getOperand().getType());
    });
    // The following operations and dialects may be introduced by the
    // rewriting rules, and are therefore marked as legal.
    target.addLegalOp<arith::CmpFOp, arith::CmpIOp, arith::ConstantOp,
                      arith::IndexCastOp, linalg::FillOp, linalg::YieldOp,
                      tensor::ExtractOp>();
    target
        .addLegalDialect<bufferization::BufferizationDialect,
                         LLVM::LLVMDialect, memref::MemRefDialect,
                         scf::SCFDialect>();
    // Translate strategy flags to strategy options.
    SparseTensorConversionOptions options(
        sparseToSparseConversionStrategy(sparseToSparse));
    // Populate with rules and apply rewriting rules.
    populateFunctionOpInterfaceTypeConversionPattern<FuncOp>(patterns,
                                                              converter);
    populateCallOpTypeConversionPattern(patterns, converter);
    populateSparseTensorConversionPatterns(converter, patterns, options);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};

} // namespace

SparseParallelizationStrategy
mlir::sparseParallelizationStrategy(int32_t flag) {
  switch (flag) {
  default:
    return SparseParallelizationStrategy::kNone;
  case 1:
    return SparseParallelizationStrategy::kDenseOuterLoop;
  case 2:
    return SparseParallelizationStrategy::kAnyStorageOuterLoop;
  case 3:
    return SparseParallelizationStrategy::kDenseAnyLoop;
  case 4:
    return SparseParallelizationStrategy::kAnyStorageAnyLoop;
  }
}

SparseVectorizationStrategy mlir::sparseVectorizationStrategy(int32_t flag) {
  switch (flag) {
  default:
    return SparseVectorizationStrategy::kNone;
  case 1:
    return SparseVectorizationStrategy::kDenseInnerLoop;
  case 2:
    return SparseVectorizationStrategy::kAnyStorageInnerLoop;
  }
}

SparseToSparseConversionStrategy
mlir::sparseToSparseConversionStrategy(int32_t flag) {
  switch (flag) {
  default:
    return SparseToSparseConversionStrategy::kAuto;
  case 1:
    return SparseToSparseConversionStrategy::kViaCOO;
  case 2:
    return SparseToSparseConversionStrategy::kDirect;
  }
}

std::unique_ptr<Pass> mlir::createSparsificationPass() {
  return std::make_unique<SparsificationPass>();
}

std::unique_ptr<Pass>
mlir::createSparsificationPass(const SparsificationOptions &options) {
  return std::make_unique<SparsificationPass>(options);
}

std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() {
  return std::make_unique<SparseTensorConversionPass>();
}

std::unique_ptr<Pass> mlir::createSparseTensorConversionPass(
    const SparseTensorConversionOptions &options) {
  return std::make_unique<SparseTensorConversionPass>(options);
}
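
// A minimal usage sketch for the factory functions above, assuming the passes
// are scheduled on a top-level module; `ctx`, `module`, and `options` are
// placeholder variables supplied by the caller:
//
//   PassManager pm(ctx);
//   pm.addPass(createSparsificationPass(options));
//   pm.addPass(createSparseTensorConversionPass());
//   if (failed(pm.run(module)))
//     /* handle the failure */;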