1 //===- SparseTensorPasses.cpp - Pass for autogen sparse tensor code -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Affine/IR/AffineOps.h" 10 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 11 #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 12 #include "mlir/Dialect/Func/IR/FuncOps.h" 13 #include "mlir/Dialect/Func/Transforms/FuncConversions.h" 14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 15 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 16 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 17 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" 18 #include "mlir/Dialect/Tensor/IR/Tensor.h" 19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 20 21 using namespace mlir; 22 using namespace mlir::sparse_tensor; 23 24 namespace { 25 26 //===----------------------------------------------------------------------===// 27 // Passes declaration. 28 //===----------------------------------------------------------------------===// 29 30 #define GEN_PASS_CLASSES 31 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc" 32 33 //===----------------------------------------------------------------------===// 34 // Passes implementation. 35 //===----------------------------------------------------------------------===// 36 37 struct SparsificationPass : public SparsificationBase<SparsificationPass> { 38 39 SparsificationPass() = default; 40 SparsificationPass(const SparsificationPass &pass) = default; 41 SparsificationPass(const SparsificationOptions &options) { 42 parallelization = options.parallelizationStrategy; 43 vectorization = options.vectorizationStrategy; 44 vectorLength = options.vectorLength; 45 enableSIMDIndex32 = options.enableSIMDIndex32; 46 enableVLAVectorization = options.enableVLAVectorization; 47 } 48 49 void runOnOperation() override { 50 auto *ctx = &getContext(); 51 RewritePatternSet patterns(ctx); 52 // Translate strategy flags to strategy options. 53 SparsificationOptions options(parallelization, vectorization, vectorLength, 54 enableSIMDIndex32, enableVLAVectorization); 55 // Apply rewriting. 56 populateSparsificationPatterns(patterns, options); 57 vector::populateVectorToVectorCanonicalizationPatterns(patterns); 58 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 59 } 60 }; 61 62 class SparseTensorTypeConverter : public TypeConverter { 63 public: 64 SparseTensorTypeConverter() { 65 addConversion([](Type type) { return type; }); 66 addConversion(convertSparseTensorTypes); 67 } 68 // Maps each sparse tensor type to an opaque pointer. 69 static Optional<Type> convertSparseTensorTypes(Type type) { 70 if (getSparseTensorEncoding(type) != nullptr) 71 return LLVM::LLVMPointerType::get(IntegerType::get(type.getContext(), 8)); 72 return llvm::None; 73 } 74 }; 75 76 struct SparseTensorConversionPass 77 : public SparseTensorConversionBase<SparseTensorConversionPass> { 78 79 SparseTensorConversionPass() = default; 80 SparseTensorConversionPass(const SparseTensorConversionPass &pass) = default; 81 SparseTensorConversionPass(const SparseTensorConversionOptions &options) { 82 sparseToSparse = static_cast<int32_t>(options.sparseToSparseStrategy); 83 } 84 85 void runOnOperation() override { 86 auto *ctx = &getContext(); 87 RewritePatternSet patterns(ctx); 88 SparseTensorTypeConverter converter; 89 ConversionTarget target(*ctx); 90 // Everything in the sparse dialect must go! 91 target.addIllegalDialect<SparseTensorDialect>(); 92 // All dynamic rules below accept new function, call, return, and tensor 93 // dim and cast operations as legal output of the rewriting provided that 94 // all sparse tensor types have been fully rewritten. 95 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) { 96 return converter.isSignatureLegal(op.getFunctionType()); 97 }); 98 target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) { 99 return converter.isSignatureLegal(op.getCalleeType()); 100 }); 101 target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) { 102 return converter.isLegal(op.getOperandTypes()); 103 }); 104 target.addDynamicallyLegalOp<tensor::DimOp>([&](tensor::DimOp op) { 105 return converter.isLegal(op.getOperandTypes()); 106 }); 107 target.addDynamicallyLegalOp<tensor::CastOp>([&](tensor::CastOp op) { 108 return converter.isLegal(op.getOperand().getType()); 109 }); 110 // The following operations and dialects may be introduced by the 111 // rewriting rules, and are therefore marked as legal. 112 target.addLegalOp<arith::CmpFOp, arith::CmpIOp, arith::ConstantOp, 113 arith::IndexCastOp, linalg::FillOp, linalg::YieldOp, 114 tensor::ExtractOp>(); 115 target 116 .addLegalDialect<bufferization::BufferizationDialect, LLVM::LLVMDialect, 117 memref::MemRefDialect, scf::SCFDialect>(); 118 // Translate strategy flags to strategy options. 119 SparseTensorConversionOptions options( 120 sparseToSparseConversionStrategy(sparseToSparse)); 121 // Populate with rules and apply rewriting rules. 122 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns, 123 converter); 124 populateCallOpTypeConversionPattern(patterns, converter); 125 populateSparseTensorConversionPatterns(converter, patterns, options); 126 if (failed(applyPartialConversion(getOperation(), target, 127 std::move(patterns)))) 128 signalPassFailure(); 129 } 130 }; 131 132 } // namespace 133 134 SparseToSparseConversionStrategy 135 mlir::sparseToSparseConversionStrategy(int32_t flag) { 136 switch (flag) { 137 default: 138 return SparseToSparseConversionStrategy::kAuto; 139 case 1: 140 return SparseToSparseConversionStrategy::kViaCOO; 141 case 2: 142 return SparseToSparseConversionStrategy::kDirect; 143 } 144 } 145 146 std::unique_ptr<Pass> mlir::createSparsificationPass() { 147 return std::make_unique<SparsificationPass>(); 148 } 149 150 std::unique_ptr<Pass> 151 mlir::createSparsificationPass(const SparsificationOptions &options) { 152 return std::make_unique<SparsificationPass>(options); 153 } 154 155 std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() { 156 return std::make_unique<SparseTensorConversionPass>(); 157 } 158 159 std::unique_ptr<Pass> mlir::createSparseTensorConversionPass( 160 const SparseTensorConversionOptions &options) { 161 return std::make_unique<SparseTensorConversionPass>(options); 162 } 163