1 //===- SparseTensorPasses.cpp - Pass for autogen sparse tensor code -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Affine/IR/AffineOps.h" 10 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 11 #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 12 #include "mlir/Dialect/Func/IR/FuncOps.h" 13 #include "mlir/Dialect/Func/Transforms/FuncConversions.h" 14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 15 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 16 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 17 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" 18 #include "mlir/Dialect/Tensor/IR/Tensor.h" 19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 20 21 using namespace mlir; 22 using namespace mlir::sparse_tensor; 23 24 namespace { 25 26 //===----------------------------------------------------------------------===// 27 // Passes declaration. 28 //===----------------------------------------------------------------------===// 29 30 #define GEN_PASS_CLASSES 31 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc" 32 33 //===----------------------------------------------------------------------===// 34 // Passes implementation. 35 //===----------------------------------------------------------------------===// 36 37 struct SparsificationPass : public SparsificationBase<SparsificationPass> { 38 39 SparsificationPass() = default; 40 SparsificationPass(const SparsificationPass &pass) = default; 41 SparsificationPass(const SparsificationOptions &options) { 42 parallelization = static_cast<int32_t>(options.parallelizationStrategy); 43 vectorization = static_cast<int32_t>(options.vectorizationStrategy); 44 vectorLength = options.vectorLength; 45 enableSIMDIndex32 = options.enableSIMDIndex32; 46 enableVLAVectorization = options.enableVLAVectorization; 47 } 48 49 void runOnOperation() override { 50 auto *ctx = &getContext(); 51 RewritePatternSet patterns(ctx); 52 // Translate strategy flags to strategy options. 53 SparsificationOptions options( 54 sparseParallelizationStrategy(parallelization), 55 sparseVectorizationStrategy(vectorization), vectorLength, 56 enableSIMDIndex32, enableVLAVectorization); 57 // Apply rewriting. 58 populateSparsificationPatterns(patterns, options); 59 vector::populateVectorToVectorCanonicalizationPatterns(patterns); 60 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 61 } 62 }; 63 64 class SparseTensorTypeConverter : public TypeConverter { 65 public: 66 SparseTensorTypeConverter() { 67 addConversion([](Type type) { return type; }); 68 addConversion(convertSparseTensorTypes); 69 } 70 // Maps each sparse tensor type to an opaque pointer. 71 static Optional<Type> convertSparseTensorTypes(Type type) { 72 if (getSparseTensorEncoding(type) != nullptr) 73 return LLVM::LLVMPointerType::get(IntegerType::get(type.getContext(), 8)); 74 return llvm::None; 75 } 76 }; 77 78 struct SparseTensorConversionPass 79 : public SparseTensorConversionBase<SparseTensorConversionPass> { 80 81 SparseTensorConversionPass() = default; 82 SparseTensorConversionPass(const SparseTensorConversionPass &pass) = default; 83 SparseTensorConversionPass(const SparseTensorConversionOptions &options) { 84 sparseToSparse = static_cast<int32_t>(options.sparseToSparseStrategy); 85 } 86 87 void runOnOperation() override { 88 auto *ctx = &getContext(); 89 RewritePatternSet patterns(ctx); 90 SparseTensorTypeConverter converter; 91 ConversionTarget target(*ctx); 92 // Everything in the sparse dialect must go! 93 target.addIllegalDialect<SparseTensorDialect>(); 94 // All dynamic rules below accept new function, call, return, and tensor 95 // dim and cast operations as legal output of the rewriting provided that 96 // all sparse tensor types have been fully rewritten. 97 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) { 98 return converter.isSignatureLegal(op.getFunctionType()); 99 }); 100 target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) { 101 return converter.isSignatureLegal(op.getCalleeType()); 102 }); 103 target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) { 104 return converter.isLegal(op.getOperandTypes()); 105 }); 106 target.addDynamicallyLegalOp<tensor::DimOp>([&](tensor::DimOp op) { 107 return converter.isLegal(op.getOperandTypes()); 108 }); 109 target.addDynamicallyLegalOp<tensor::CastOp>([&](tensor::CastOp op) { 110 return converter.isLegal(op.getOperand().getType()); 111 }); 112 // The following operations and dialects may be introduced by the 113 // rewriting rules, and are therefore marked as legal. 114 target.addLegalOp<arith::CmpFOp, arith::CmpIOp, arith::ConstantOp, 115 arith::IndexCastOp, linalg::FillOp, linalg::YieldOp, 116 tensor::ExtractOp>(); 117 target 118 .addLegalDialect<bufferization::BufferizationDialect, LLVM::LLVMDialect, 119 memref::MemRefDialect, scf::SCFDialect>(); 120 // Translate strategy flags to strategy options. 121 SparseTensorConversionOptions options( 122 sparseToSparseConversionStrategy(sparseToSparse)); 123 // Populate with rules and apply rewriting rules. 124 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns, 125 converter); 126 populateCallOpTypeConversionPattern(patterns, converter); 127 populateSparseTensorConversionPatterns(converter, patterns, options); 128 if (failed(applyPartialConversion(getOperation(), target, 129 std::move(patterns)))) 130 signalPassFailure(); 131 } 132 }; 133 134 } // namespace 135 136 SparseParallelizationStrategy 137 mlir::sparseParallelizationStrategy(int32_t flag) { 138 switch (flag) { 139 default: 140 return SparseParallelizationStrategy::kNone; 141 case 1: 142 return SparseParallelizationStrategy::kDenseOuterLoop; 143 case 2: 144 return SparseParallelizationStrategy::kAnyStorageOuterLoop; 145 case 3: 146 return SparseParallelizationStrategy::kDenseAnyLoop; 147 case 4: 148 return SparseParallelizationStrategy::kAnyStorageAnyLoop; 149 } 150 } 151 152 SparseVectorizationStrategy mlir::sparseVectorizationStrategy(int32_t flag) { 153 switch (flag) { 154 default: 155 return SparseVectorizationStrategy::kNone; 156 case 1: 157 return SparseVectorizationStrategy::kDenseInnerLoop; 158 case 2: 159 return SparseVectorizationStrategy::kAnyStorageInnerLoop; 160 } 161 } 162 163 SparseToSparseConversionStrategy 164 mlir::sparseToSparseConversionStrategy(int32_t flag) { 165 switch (flag) { 166 default: 167 return SparseToSparseConversionStrategy::kAuto; 168 case 1: 169 return SparseToSparseConversionStrategy::kViaCOO; 170 case 2: 171 return SparseToSparseConversionStrategy::kDirect; 172 } 173 } 174 175 std::unique_ptr<Pass> mlir::createSparsificationPass() { 176 return std::make_unique<SparsificationPass>(); 177 } 178 179 std::unique_ptr<Pass> 180 mlir::createSparsificationPass(const SparsificationOptions &options) { 181 return std::make_unique<SparsificationPass>(options); 182 } 183 184 std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() { 185 return std::make_unique<SparseTensorConversionPass>(); 186 } 187 188 std::unique_ptr<Pass> mlir::createSparseTensorConversionPass( 189 const SparseTensorConversionOptions &options) { 190 return std::make_unique<SparseTensorConversionPass>(options); 191 } 192