//===- SparseTensorPasses.cpp - Pass for autogen sparse tensor code ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

//===----------------------------------------------------------------------===//
// Passes declaration.
//===----------------------------------------------------------------------===//

#define GEN_PASS_CLASSES
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"

//===----------------------------------------------------------------------===//
// Passes implementation.
//===----------------------------------------------------------------------===//

struct SparsificationPass : public SparsificationBase<SparsificationPass> {

  SparsificationPass() = default;
  SparsificationPass(const SparsificationPass &pass) = default;
  SparsificationPass(const SparsificationOptions &options) {
    parallelization = static_cast<int32_t>(options.parallelizationStrategy);
    vectorization = static_cast<int32_t>(options.vectorizationStrategy);
    vectorLength = options.vectorLength;
    enableSIMDIndex32 = options.enableSIMDIndex32;
    enableVLAVectorization = options.enableVLAVectorization;
  }

  void runOnOperation() override {
    auto *ctx = &getContext();
    // Apply pre-rewriting.
    RewritePatternSet prePatterns(ctx);
    populateSparseTensorRewriting(prePatterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(prePatterns));
    // Translate strategy flags to strategy options.
    SparsificationOptions options(
        sparseParallelizationStrategy(parallelization),
        sparseVectorizationStrategy(vectorization), vectorLength,
        enableSIMDIndex32, enableVLAVectorization);
    // Apply sparsification and vector cleanup rewriting.
    RewritePatternSet patterns(ctx);
    populateSparsificationPatterns(patterns, options);
    vector::populateVectorToVectorCanonicalizationPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

class SparseTensorTypeConverter : public TypeConverter {
public:
  SparseTensorTypeConverter() {
    addConversion([](Type type) { return type; });
    addConversion(convertSparseTensorTypes);
  }
  // Maps each sparse tensor type to an opaque pointer.
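  // The opaque pointer serves as the handle that the sparse runtime support
  // library uses to refer to the underlying sparse storage scheme.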
  static Optional<Type> convertSparseTensorTypes(Type type) {
    if (getSparseTensorEncoding(type) != nullptr)
      return LLVM::LLVMPointerType::get(
          IntegerType::get(type.getContext(), 8));
    return llvm::None;
  }
};

struct SparseTensorConversionPass
    : public SparseTensorConversionBase<SparseTensorConversionPass> {

  SparseTensorConversionPass() = default;
  SparseTensorConversionPass(const SparseTensorConversionPass &pass) = default;
  SparseTensorConversionPass(const SparseTensorConversionOptions &options) {
    sparseToSparse = static_cast<int32_t>(options.sparseToSparseStrategy);
  }

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    SparseTensorTypeConverter converter;
    ConversionTarget target(*ctx);
    // Everything in the sparse dialect must go!
    target.addIllegalDialect<SparseTensorDialect>();
    // All dynamic rules below accept new function, call, return, and various
    // tensor and bufferization operations as legal output of the rewriting
    // provided that all sparse tensor types have been fully rewritten.
    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
      return converter.isSignatureLegal(op.getFunctionType());
    });
    target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
      return converter.isSignatureLegal(op.getCalleeType());
    });
    target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
      return converter.isLegal(op.getOperandTypes());
    });
    target.addDynamicallyLegalOp<tensor::DimOp>([&](tensor::DimOp op) {
      return converter.isLegal(op.getOperandTypes());
    });
    target.addDynamicallyLegalOp<tensor::CastOp>([&](tensor::CastOp op) {
      return converter.isLegal(op.getSource().getType()) &&
             converter.isLegal(op.getDest().getType());
    });
    target.addDynamicallyLegalOp<tensor::ExpandShapeOp>(
        [&](tensor::ExpandShapeOp op) {
          return converter.isLegal(op.getSrc().getType()) &&
                 converter.isLegal(op.getResult().getType());
        });
    target.addDynamicallyLegalOp<tensor::CollapseShapeOp>(
        [&](tensor::CollapseShapeOp op) {
          return converter.isLegal(op.getSrc().getType()) &&
                 converter.isLegal(op.getResult().getType());
        });
    target.addDynamicallyLegalOp<bufferization::AllocTensorOp>(
        [&](bufferization::AllocTensorOp op) {
          // Dense tensors are legal, sparse tensors are not.
          return !static_cast<bool>(op.getType().getEncoding());
        });
    // The following operations and dialects may be introduced by the
    // rewriting rules, and are therefore marked as legal.
    target.addLegalOp<bufferization::ToMemrefOp, bufferization::ToTensorOp,
                      complex::ConstantOp, complex::NotEqualOp, linalg::FillOp,
                      linalg::YieldOp, tensor::ExtractOp>();
    target.addLegalDialect<
        arith::ArithmeticDialect, bufferization::BufferizationDialect,
        LLVM::LLVMDialect, memref::MemRefDialect, scf::SCFDialect>();
    // Translate strategy flags to strategy options.
    SparseTensorConversionOptions options(
        sparseToSparseConversionStrategy(sparseToSparse));
    // Populate with rules and apply rewriting rules.
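    // Function signatures and call sites are rewritten together with the
    // sparse tensor operations so that sparse tensor types are converted
    // consistently across function boundaries.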
    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
                                                                   converter);
    populateCallOpTypeConversionPattern(patterns, converter);
    populateSparseTensorConversionPatterns(converter, patterns, options);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};

} // namespace

SparseParallelizationStrategy
mlir::sparseParallelizationStrategy(int32_t flag) {
  switch (flag) {
  default:
    return SparseParallelizationStrategy::kNone;
  case 1:
    return SparseParallelizationStrategy::kDenseOuterLoop;
  case 2:
    return SparseParallelizationStrategy::kAnyStorageOuterLoop;
  case 3:
    return SparseParallelizationStrategy::kDenseAnyLoop;
  case 4:
    return SparseParallelizationStrategy::kAnyStorageAnyLoop;
  }
}

SparseVectorizationStrategy mlir::sparseVectorizationStrategy(int32_t flag) {
  switch (flag) {
  default:
    return SparseVectorizationStrategy::kNone;
  case 1:
    return SparseVectorizationStrategy::kDenseInnerLoop;
  case 2:
    return SparseVectorizationStrategy::kAnyStorageInnerLoop;
  }
}

SparseToSparseConversionStrategy
mlir::sparseToSparseConversionStrategy(int32_t flag) {
  switch (flag) {
  default:
    return SparseToSparseConversionStrategy::kAuto;
  case 1:
    return SparseToSparseConversionStrategy::kViaCOO;
  case 2:
    return SparseToSparseConversionStrategy::kDirect;
  }
}

std::unique_ptr<Pass> mlir::createSparsificationPass() {
  return std::make_unique<SparsificationPass>();
}

std::unique_ptr<Pass>
mlir::createSparsificationPass(const SparsificationOptions &options) {
  return std::make_unique<SparsificationPass>(options);
}

std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() {
  return std::make_unique<SparseTensorConversionPass>();
}

std::unique_ptr<Pass> mlir::createSparseTensorConversionPass(
    const SparseTensorConversionOptions &options) {
  return std::make_unique<SparseTensorConversionPass>(options);
}
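
// A minimal usage sketch, assuming an OpPassManager `pm` built elsewhere
// (e.g. in a sparse compiler pipeline): the two passes are typically
// scheduled back to back, sparsification first, then conversion:
//
//   pm.addPass(mlir::createSparsificationPass());
//   pm.addPass(mlir::createSparseTensorConversionPass());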