1 //===- SparseTensorPasses.cpp - Pass for autogen sparse tensor code -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Affine/IR/AffineOps.h" 10 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 11 #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 12 #include "mlir/Dialect/Complex/IR/Complex.h" 13 #include "mlir/Dialect/Func/IR/FuncOps.h" 14 #include "mlir/Dialect/Func/Transforms/FuncConversions.h" 15 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 16 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 17 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 18 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" 19 #include "mlir/Dialect/Tensor/IR/Tensor.h" 20 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 21 22 using namespace mlir; 23 using namespace mlir::sparse_tensor; 24 25 namespace { 26 27 //===----------------------------------------------------------------------===// 28 // Passes declaration. 29 //===----------------------------------------------------------------------===// 30 31 #define GEN_PASS_CLASSES 32 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc" 33 34 //===----------------------------------------------------------------------===// 35 // Passes implementation. 36 //===----------------------------------------------------------------------===// 37 38 struct SparsificationPass : public SparsificationBase<SparsificationPass> { 39 40 SparsificationPass() = default; 41 SparsificationPass(const SparsificationPass &pass) = default; 42 SparsificationPass(const SparsificationOptions &options) { 43 parallelization = static_cast<int32_t>(options.parallelizationStrategy); 44 vectorization = static_cast<int32_t>(options.vectorizationStrategy); 45 vectorLength = options.vectorLength; 46 enableSIMDIndex32 = options.enableSIMDIndex32; 47 enableVLAVectorization = options.enableVLAVectorization; 48 } 49 50 void runOnOperation() override { 51 auto *ctx = &getContext(); 52 RewritePatternSet patterns(ctx); 53 // Translate strategy flags to strategy options. 54 SparsificationOptions options( 55 sparseParallelizationStrategy(parallelization), 56 sparseVectorizationStrategy(vectorization), vectorLength, 57 enableSIMDIndex32, enableVLAVectorization); 58 // Apply rewriting. 59 populateSparsificationPatterns(patterns, options); 60 vector::populateVectorToVectorCanonicalizationPatterns(patterns); 61 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 62 } 63 }; 64 65 class SparseTensorTypeConverter : public TypeConverter { 66 public: 67 SparseTensorTypeConverter() { 68 addConversion([](Type type) { return type; }); 69 addConversion(convertSparseTensorTypes); 70 } 71 // Maps each sparse tensor type to an opaque pointer. 72 static Optional<Type> convertSparseTensorTypes(Type type) { 73 if (getSparseTensorEncoding(type) != nullptr) 74 return LLVM::LLVMPointerType::get(IntegerType::get(type.getContext(), 8)); 75 return llvm::None; 76 } 77 }; 78 79 struct SparseTensorConversionPass 80 : public SparseTensorConversionBase<SparseTensorConversionPass> { 81 82 SparseTensorConversionPass() = default; 83 SparseTensorConversionPass(const SparseTensorConversionPass &pass) = default; 84 SparseTensorConversionPass(const SparseTensorConversionOptions &options) { 85 sparseToSparse = static_cast<int32_t>(options.sparseToSparseStrategy); 86 } 87 88 void runOnOperation() override { 89 auto *ctx = &getContext(); 90 RewritePatternSet patterns(ctx); 91 SparseTensorTypeConverter converter; 92 ConversionTarget target(*ctx); 93 // Everything in the sparse dialect must go! 94 target.addIllegalDialect<SparseTensorDialect>(); 95 // All dynamic rules below accept new function, call, return, and various 96 // tensor and bufferization operations as legal output of the rewriting 97 // provided that all sparse tensor types have been fully rewritten. 98 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) { 99 return converter.isSignatureLegal(op.getFunctionType()); 100 }); 101 target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) { 102 return converter.isSignatureLegal(op.getCalleeType()); 103 }); 104 target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) { 105 return converter.isLegal(op.getOperandTypes()); 106 }); 107 target.addDynamicallyLegalOp<tensor::DimOp>([&](tensor::DimOp op) { 108 return converter.isLegal(op.getOperandTypes()); 109 }); 110 target.addDynamicallyLegalOp<tensor::CastOp>([&](tensor::CastOp op) { 111 return converter.isLegal(op.getSource().getType()) && 112 converter.isLegal(op.getDest().getType()); 113 }); 114 target.addDynamicallyLegalOp<tensor::ExpandShapeOp>( 115 [&](tensor::ExpandShapeOp op) { 116 return converter.isLegal(op.getSrc().getType()) && 117 converter.isLegal(op.getResult().getType()); 118 }); 119 target.addDynamicallyLegalOp<tensor::CollapseShapeOp>( 120 [&](tensor::CollapseShapeOp op) { 121 return converter.isLegal(op.getSrc().getType()) && 122 converter.isLegal(op.getResult().getType()); 123 }); 124 target.addDynamicallyLegalOp<bufferization::AllocTensorOp>( 125 [&](bufferization::AllocTensorOp op) { 126 return converter.isLegal(op.getType()); 127 }); 128 // The following operations and dialects may be introduced by the 129 // rewriting rules, and are therefore marked as legal. 130 target.addLegalOp<arith::CmpFOp, arith::CmpIOp, arith::ConstantOp, 131 arith::IndexCastOp, complex::ConstantOp, 132 complex::NotEqualOp, linalg::FillOp, linalg::YieldOp, 133 tensor::ExtractOp>(); 134 target 135 .addLegalDialect<bufferization::BufferizationDialect, LLVM::LLVMDialect, 136 memref::MemRefDialect, scf::SCFDialect>(); 137 // Translate strategy flags to strategy options. 138 SparseTensorConversionOptions options( 139 sparseToSparseConversionStrategy(sparseToSparse)); 140 // Populate with rules and apply rewriting rules. 141 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns, 142 converter); 143 populateCallOpTypeConversionPattern(patterns, converter); 144 populateSparseTensorConversionPatterns(converter, patterns, options); 145 if (failed(applyPartialConversion(getOperation(), target, 146 std::move(patterns)))) 147 signalPassFailure(); 148 } 149 }; 150 151 } // namespace 152 153 SparseParallelizationStrategy 154 mlir::sparseParallelizationStrategy(int32_t flag) { 155 switch (flag) { 156 default: 157 return SparseParallelizationStrategy::kNone; 158 case 1: 159 return SparseParallelizationStrategy::kDenseOuterLoop; 160 case 2: 161 return SparseParallelizationStrategy::kAnyStorageOuterLoop; 162 case 3: 163 return SparseParallelizationStrategy::kDenseAnyLoop; 164 case 4: 165 return SparseParallelizationStrategy::kAnyStorageAnyLoop; 166 } 167 } 168 169 SparseVectorizationStrategy mlir::sparseVectorizationStrategy(int32_t flag) { 170 switch (flag) { 171 default: 172 return SparseVectorizationStrategy::kNone; 173 case 1: 174 return SparseVectorizationStrategy::kDenseInnerLoop; 175 case 2: 176 return SparseVectorizationStrategy::kAnyStorageInnerLoop; 177 } 178 } 179 180 SparseToSparseConversionStrategy 181 mlir::sparseToSparseConversionStrategy(int32_t flag) { 182 switch (flag) { 183 default: 184 return SparseToSparseConversionStrategy::kAuto; 185 case 1: 186 return SparseToSparseConversionStrategy::kViaCOO; 187 case 2: 188 return SparseToSparseConversionStrategy::kDirect; 189 } 190 } 191 192 std::unique_ptr<Pass> mlir::createSparsificationPass() { 193 return std::make_unique<SparsificationPass>(); 194 } 195 196 std::unique_ptr<Pass> 197 mlir::createSparsificationPass(const SparsificationOptions &options) { 198 return std::make_unique<SparsificationPass>(options); 199 } 200 201 std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() { 202 return std::make_unique<SparseTensorConversionPass>(); 203 } 204 205 std::unique_ptr<Pass> mlir::createSparseTensorConversionPass( 206 const SparseTensorConversionOptions &options) { 207 return std::make_unique<SparseTensorConversionPass>(options); 208 } 209