1 //===- SparseTensorPasses.cpp - Pass for autogen sparse tensor code -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 10 #include "mlir/Dialect/Func/IR/FuncOps.h" 11 #include "mlir/Dialect/Func/Transforms/FuncConversions.h" 12 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 13 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 14 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 15 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" 16 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 17 18 using namespace mlir; 19 using namespace mlir::sparse_tensor; 20 21 namespace { 22 23 //===----------------------------------------------------------------------===// 24 // Passes declaration. 25 //===----------------------------------------------------------------------===// 26 27 #define GEN_PASS_CLASSES 28 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc" 29 30 //===----------------------------------------------------------------------===// 31 // Passes implementation. 32 //===----------------------------------------------------------------------===// 33 34 struct SparsificationPass : public SparsificationBase<SparsificationPass> { 35 36 SparsificationPass() = default; 37 SparsificationPass(const SparsificationPass &pass) = default; 38 SparsificationPass(const SparsificationOptions &options) { 39 parallelization = static_cast<int32_t>(options.parallelizationStrategy); 40 vectorization = static_cast<int32_t>(options.vectorizationStrategy); 41 vectorLength = options.vectorLength; 42 enableSIMDIndex32 = options.enableSIMDIndex32; 43 } 44 45 void runOnOperation() override { 46 auto *ctx = &getContext(); 47 RewritePatternSet patterns(ctx); 48 // Translate strategy flags to strategy options. 49 SparsificationOptions options( 50 sparseParallelizationStrategy(parallelization), 51 sparseVectorizationStrategy(vectorization), vectorLength, 52 enableSIMDIndex32); 53 // Apply rewriting. 54 populateSparsificationPatterns(patterns, options); 55 vector::populateVectorToVectorCanonicalizationPatterns(patterns); 56 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 57 } 58 }; 59 60 class SparseTensorTypeConverter : public TypeConverter { 61 public: 62 SparseTensorTypeConverter() { 63 addConversion([](Type type) { return type; }); 64 addConversion(convertSparseTensorTypes); 65 } 66 // Maps each sparse tensor type to an opaque pointer. 67 static Optional<Type> convertSparseTensorTypes(Type type) { 68 if (getSparseTensorEncoding(type) != nullptr) 69 return LLVM::LLVMPointerType::get(IntegerType::get(type.getContext(), 8)); 70 return llvm::None; 71 } 72 }; 73 74 struct SparseTensorConversionPass 75 : public SparseTensorConversionBase<SparseTensorConversionPass> { 76 void runOnOperation() override { 77 auto *ctx = &getContext(); 78 RewritePatternSet patterns(ctx); 79 SparseTensorTypeConverter converter; 80 ConversionTarget target(*ctx); 81 // Everything in the sparse dialect must go! 82 target.addIllegalDialect<SparseTensorDialect>(); 83 // All dynamic rules below accept new function, call, return, and tensor 84 // dim and cast operations as legal output of the rewriting provided that 85 // all sparse tensor types have been fully rewritten. 86 target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) { 87 return converter.isSignatureLegal(op.getFunctionType()); 88 }); 89 target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) { 90 return converter.isSignatureLegal(op.getCalleeType()); 91 }); 92 target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) { 93 return converter.isLegal(op.getOperandTypes()); 94 }); 95 target.addDynamicallyLegalOp<tensor::DimOp>([&](tensor::DimOp op) { 96 return converter.isLegal(op.getOperandTypes()); 97 }); 98 target.addDynamicallyLegalOp<tensor::CastOp>([&](tensor::CastOp op) { 99 return converter.isLegal(op.getOperand().getType()); 100 }); 101 // The following operations and dialects may be introduced by the 102 // rewriting rules, and are therefore marked as legal. 103 target.addLegalOp<arith::CmpFOp, arith::CmpIOp, arith::ConstantOp, 104 arith::IndexCastOp, linalg::FillOp, linalg::YieldOp, 105 tensor::ExtractOp>(); 106 target 107 .addLegalDialect<bufferization::BufferizationDialect, LLVM::LLVMDialect, 108 memref::MemRefDialect, scf::SCFDialect>(); 109 // Populate with rules and apply rewriting rules. 110 populateFunctionOpInterfaceTypeConversionPattern<FuncOp>(patterns, 111 converter); 112 populateCallOpTypeConversionPattern(patterns, converter); 113 populateSparseTensorConversionPatterns(converter, patterns); 114 if (failed(applyPartialConversion(getOperation(), target, 115 std::move(patterns)))) 116 signalPassFailure(); 117 } 118 }; 119 120 } // namespace 121 122 SparseParallelizationStrategy 123 mlir::sparseParallelizationStrategy(int32_t flag) { 124 switch (flag) { 125 default: 126 return SparseParallelizationStrategy::kNone; 127 case 1: 128 return SparseParallelizationStrategy::kDenseOuterLoop; 129 case 2: 130 return SparseParallelizationStrategy::kAnyStorageOuterLoop; 131 case 3: 132 return SparseParallelizationStrategy::kDenseAnyLoop; 133 case 4: 134 return SparseParallelizationStrategy::kAnyStorageAnyLoop; 135 } 136 } 137 138 SparseVectorizationStrategy mlir::sparseVectorizationStrategy(int32_t flag) { 139 switch (flag) { 140 default: 141 return SparseVectorizationStrategy::kNone; 142 case 1: 143 return SparseVectorizationStrategy::kDenseInnerLoop; 144 case 2: 145 return SparseVectorizationStrategy::kAnyStorageInnerLoop; 146 } 147 } 148 149 std::unique_ptr<Pass> mlir::createSparsificationPass() { 150 return std::make_unique<SparsificationPass>(); 151 } 152 153 std::unique_ptr<Pass> 154 mlir::createSparsificationPass(const SparsificationOptions &options) { 155 return std::make_unique<SparsificationPass>(options); 156 } 157 158 std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() { 159 return std::make_unique<SparseTensorConversionPass>(); 160 } 161