1 //===- SparseTensorPasses.cpp - Pass for autogen sparse tensor code -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 10 #include "mlir/Dialect/Func/Transforms/FuncConversions.h" 11 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 12 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 13 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 14 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" 15 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 16 17 using namespace mlir; 18 using namespace mlir::sparse_tensor; 19 20 namespace { 21 22 //===----------------------------------------------------------------------===// 23 // Passes declaration. 24 //===----------------------------------------------------------------------===// 25 26 #define GEN_PASS_CLASSES 27 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc" 28 29 //===----------------------------------------------------------------------===// 30 // Passes implementation. 31 //===----------------------------------------------------------------------===// 32 33 struct SparsificationPass : public SparsificationBase<SparsificationPass> { 34 35 SparsificationPass() = default; 36 SparsificationPass(const SparsificationPass &pass) = default; 37 SparsificationPass(const SparsificationOptions &options) { 38 parallelization = static_cast<int32_t>(options.parallelizationStrategy); 39 vectorization = static_cast<int32_t>(options.vectorizationStrategy); 40 vectorLength = options.vectorLength; 41 enableSIMDIndex32 = options.enableSIMDIndex32; 42 } 43 44 void runOnOperation() override { 45 auto *ctx = &getContext(); 46 RewritePatternSet patterns(ctx); 47 // Translate strategy flags to strategy options. 48 SparsificationOptions options( 49 sparseParallelizationStrategy(parallelization), 50 sparseVectorizationStrategy(vectorization), vectorLength, 51 enableSIMDIndex32); 52 // Apply rewriting. 53 populateSparsificationPatterns(patterns, options); 54 vector::populateVectorToVectorCanonicalizationPatterns(patterns); 55 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 56 } 57 }; 58 59 class SparseTensorTypeConverter : public TypeConverter { 60 public: 61 SparseTensorTypeConverter() { 62 addConversion([](Type type) { return type; }); 63 addConversion(convertSparseTensorTypes); 64 } 65 // Maps each sparse tensor type to an opaque pointer. 66 static Optional<Type> convertSparseTensorTypes(Type type) { 67 if (getSparseTensorEncoding(type) != nullptr) 68 return LLVM::LLVMPointerType::get(IntegerType::get(type.getContext(), 8)); 69 return llvm::None; 70 } 71 }; 72 73 struct SparseTensorConversionPass 74 : public SparseTensorConversionBase<SparseTensorConversionPass> { 75 void runOnOperation() override { 76 auto *ctx = &getContext(); 77 RewritePatternSet patterns(ctx); 78 SparseTensorTypeConverter converter; 79 ConversionTarget target(*ctx); 80 // Everything in the sparse dialect must go! 81 target.addIllegalDialect<SparseTensorDialect>(); 82 // All dynamic rules below accept new function, call, return, and tensor 83 // dim and cast operations as legal output of the rewriting provided that 84 // all sparse tensor types have been fully rewritten. 85 target.addDynamicallyLegalOp<FuncOp>( 86 [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); 87 target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) { 88 return converter.isSignatureLegal(op.getCalleeType()); 89 }); 90 target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) { 91 return converter.isLegal(op.getOperandTypes()); 92 }); 93 target.addDynamicallyLegalOp<tensor::DimOp>([&](tensor::DimOp op) { 94 return converter.isLegal(op.getOperandTypes()); 95 }); 96 target.addDynamicallyLegalOp<tensor::CastOp>([&](tensor::CastOp op) { 97 return converter.isLegal(op.getOperand().getType()); 98 }); 99 // The following operations and dialects may be introduced by the 100 // rewriting rules, and are therefore marked as legal. 101 target.addLegalOp<arith::CmpFOp, arith::CmpIOp, arith::ConstantOp, 102 arith::IndexCastOp, linalg::FillOp, linalg::YieldOp, 103 tensor::ExtractOp>(); 104 target 105 .addLegalDialect<bufferization::BufferizationDialect, LLVM::LLVMDialect, 106 memref::MemRefDialect, scf::SCFDialect>(); 107 // Populate with rules and apply rewriting rules. 108 populateFunctionOpInterfaceTypeConversionPattern<FuncOp>(patterns, 109 converter); 110 populateCallOpTypeConversionPattern(patterns, converter); 111 populateSparseTensorConversionPatterns(converter, patterns); 112 if (failed(applyPartialConversion(getOperation(), target, 113 std::move(patterns)))) 114 signalPassFailure(); 115 } 116 }; 117 118 } // namespace 119 120 SparseParallelizationStrategy 121 mlir::sparseParallelizationStrategy(int32_t flag) { 122 switch (flag) { 123 default: 124 return SparseParallelizationStrategy::kNone; 125 case 1: 126 return SparseParallelizationStrategy::kDenseOuterLoop; 127 case 2: 128 return SparseParallelizationStrategy::kAnyStorageOuterLoop; 129 case 3: 130 return SparseParallelizationStrategy::kDenseAnyLoop; 131 case 4: 132 return SparseParallelizationStrategy::kAnyStorageAnyLoop; 133 } 134 } 135 136 SparseVectorizationStrategy mlir::sparseVectorizationStrategy(int32_t flag) { 137 switch (flag) { 138 default: 139 return SparseVectorizationStrategy::kNone; 140 case 1: 141 return SparseVectorizationStrategy::kDenseInnerLoop; 142 case 2: 143 return SparseVectorizationStrategy::kAnyStorageInnerLoop; 144 } 145 } 146 147 std::unique_ptr<Pass> mlir::createSparsificationPass() { 148 return std::make_unique<SparsificationPass>(); 149 } 150 151 std::unique_ptr<Pass> 152 mlir::createSparsificationPass(const SparsificationOptions &options) { 153 return std::make_unique<SparsificationPass>(options); 154 } 155 156 std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() { 157 return std::make_unique<SparseTensorConversionPass>(); 158 } 159