1 //===- SparseTensorPasses.cpp - Pass for autogen sparse tensor code -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 10 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 11 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 12 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" 13 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h" 14 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 15 16 using namespace mlir; 17 using namespace mlir::sparse_tensor; 18 19 namespace { 20 21 //===----------------------------------------------------------------------===// 22 // Passes declaration. 23 //===----------------------------------------------------------------------===// 24 25 #define GEN_PASS_CLASSES 26 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc" 27 28 //===----------------------------------------------------------------------===// 29 // Passes implementation. 30 //===----------------------------------------------------------------------===// 31 32 struct SparsificationPass : public SparsificationBase<SparsificationPass> { 33 34 SparsificationPass() = default; 35 SparsificationPass(const SparsificationPass &pass) 36 : SparsificationBase<SparsificationPass>() {} 37 38 /// Returns parallelization strategy given on command line. 39 SparseParallelizationStrategy parallelOption() { 40 switch (parallelization) { 41 default: 42 return SparseParallelizationStrategy::kNone; 43 case 1: 44 return SparseParallelizationStrategy::kDenseOuterLoop; 45 case 2: 46 return SparseParallelizationStrategy::kAnyStorageOuterLoop; 47 case 3: 48 return SparseParallelizationStrategy::kDenseAnyLoop; 49 case 4: 50 return SparseParallelizationStrategy::kAnyStorageAnyLoop; 51 } 52 } 53 54 /// Returns vectorization strategy given on command line. 55 SparseVectorizationStrategy vectorOption() { 56 switch (vectorization) { 57 default: 58 return SparseVectorizationStrategy::kNone; 59 case 1: 60 return SparseVectorizationStrategy::kDenseInnerLoop; 61 case 2: 62 return SparseVectorizationStrategy::kAnyStorageInnerLoop; 63 } 64 } 65 66 void runOnOperation() override { 67 auto *ctx = &getContext(); 68 RewritePatternSet patterns(ctx); 69 // Translate strategy flags to strategy options. 70 SparsificationOptions options(parallelOption(), vectorOption(), 71 vectorLength, enableSIMDIndex32); 72 // Apply rewriting. 73 populateSparsificationPatterns(patterns, options); 74 vector::populateVectorToVectorCanonicalizationPatterns(patterns); 75 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 76 } 77 }; 78 79 class SparseTensorTypeConverter : public TypeConverter { 80 public: 81 SparseTensorTypeConverter() { 82 addConversion([](Type type) { return type; }); 83 addConversion(convertSparseTensorTypes); 84 } 85 // Maps each sparse tensor type to an opaque pointer. 86 static Optional<Type> convertSparseTensorTypes(Type type) { 87 if (getSparseTensorEncoding(type) != nullptr) 88 return LLVM::LLVMPointerType::get(IntegerType::get(type.getContext(), 8)); 89 return llvm::None; 90 } 91 }; 92 93 struct SparseTensorConversionPass 94 : public SparseTensorConversionBase<SparseTensorConversionPass> { 95 void runOnOperation() override { 96 auto *ctx = &getContext(); 97 RewritePatternSet patterns(ctx); 98 SparseTensorTypeConverter converter; 99 ConversionTarget target(*ctx); 100 // Everything in the sparse dialect must go! 101 target.addIllegalDialect<SparseTensorDialect>(); 102 // All dynamic rules below accept new function, call, return, and tensor 103 // dim and cast operations as legal output of the rewriting provided that 104 // all sparse tensor types have been fully rewritten. 105 target.addDynamicallyLegalOp<FuncOp>( 106 [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); 107 target.addDynamicallyLegalOp<CallOp>([&](CallOp op) { 108 return converter.isSignatureLegal(op.getCalleeType()); 109 }); 110 target.addDynamicallyLegalOp<ReturnOp>( 111 [&](ReturnOp op) { return converter.isLegal(op.getOperandTypes()); }); 112 target.addDynamicallyLegalOp<tensor::DimOp>([&](tensor::DimOp op) { 113 return converter.isLegal(op.getOperandTypes()); 114 }); 115 target.addDynamicallyLegalOp<tensor::CastOp>([&](tensor::CastOp op) { 116 return converter.isLegal(op.getOperand().getType()); 117 }); 118 // The following operations and dialects may be introduced by the 119 // rewriting rules, and are therefore marked as legal. 120 target.addLegalOp<arith::CmpFOp, arith::CmpIOp, arith::ConstantOp, 121 arith::IndexCastOp, linalg::FillOp, linalg::YieldOp, 122 tensor::ExtractOp>(); 123 target.addLegalDialect<LLVM::LLVMDialect, memref::MemRefDialect, 124 scf::SCFDialect>(); 125 // Populate with rules and apply rewriting rules. 126 populateFuncOpTypeConversionPattern(patterns, converter); 127 populateCallOpTypeConversionPattern(patterns, converter); 128 populateSparseTensorConversionPatterns(converter, patterns); 129 if (failed(applyPartialConversion(getOperation(), target, 130 std::move(patterns)))) 131 signalPassFailure(); 132 } 133 }; 134 135 } // end anonymous namespace 136 137 std::unique_ptr<Pass> mlir::createSparsificationPass() { 138 return std::make_unique<SparsificationPass>(); 139 } 140 141 std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() { 142 return std::make_unique<SparseTensorConversionPass>(); 143 } 144