//===- SparseTensorPasses.cpp - Pass for autogen sparse tensor code -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

//===----------------------------------------------------------------------===//
// Passes declaration.
//===----------------------------------------------------------------------===//

#define GEN_PASS_CLASSES
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"

//===----------------------------------------------------------------------===//
// Passes implementation.
//===----------------------------------------------------------------------===//

struct SparsificationPass : public SparsificationBase<SparsificationPass> {

  SparsificationPass() = default;
  SparsificationPass(const SparsificationPass &pass)
      : SparsificationBase<SparsificationPass>() {}

  /// Returns parallelization strategy given on command line.
  SparseParallelizationStrategy parallelOption() {
    switch (parallelization) {
    default:
      return SparseParallelizationStrategy::kNone;
    case 1:
      return SparseParallelizationStrategy::kDenseOuterLoop;
    case 2:
      return SparseParallelizationStrategy::kAnyStorageOuterLoop;
    case 3:
      return SparseParallelizationStrategy::kDenseAnyLoop;
    case 4:
      return SparseParallelizationStrategy::kAnyStorageAnyLoop;
    }
  }

  /// Returns vectorization strategy given on command line.
  SparseVectorizationStrategy vectorOption() {
    switch (vectorization) {
    default:
      return SparseVectorizationStrategy::kNone;
    case 1:
      return SparseVectorizationStrategy::kDenseInnerLoop;
    case 2:
      return SparseVectorizationStrategy::kAnyStorageInnerLoop;
    }
  }

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    // Translate strategy flags to strategy options.
    SparsificationOptions options(parallelOption(), vectorOption(),
                                  vectorLength, enableSIMDIndex32);
    // Apply rewriting.
    populateSparsificationPatterns(patterns, options);
    vector::populateVectorToVectorCanonicalizationPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

class SparseTensorTypeConverter : public TypeConverter {
public:
  SparseTensorTypeConverter() {
    addConversion([](Type type) { return type; });
    addConversion(convertSparseTensorTypes);
  }
  // Maps each sparse tensor type to an opaque pointer.
  static Optional<Type> convertSparseTensorTypes(Type type) {
    if (getSparseTensorEncoding(type) != nullptr)
      return LLVM::LLVMPointerType::get(
          IntegerType::get(type.getContext(), 8));
    return llvm::None;
  }
};

struct SparseTensorConversionPass
    : public SparseTensorConversionBase<SparseTensorConversionPass> {
  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    SparseTensorTypeConverter converter;
    ConversionTarget target(*ctx);
    // Everything in the sparse dialect must go!
    target.addIllegalDialect<SparseTensorDialect>();
    // All dynamic rules below accept new function, call, return, and tensor
    // dim and cast operations as legal output of the rewriting provided that
    // all sparse tensor types have been fully rewritten.
    target.addDynamicallyLegalOp<FuncOp>(
        [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
    target.addDynamicallyLegalOp<CallOp>([&](CallOp op) {
      return converter.isSignatureLegal(op.getCalleeType());
    });
    target.addDynamicallyLegalOp<ReturnOp>(
        [&](ReturnOp op) { return converter.isLegal(op.getOperandTypes()); });
    target.addDynamicallyLegalOp<tensor::DimOp>([&](tensor::DimOp op) {
      return converter.isLegal(op.getOperandTypes());
    });
    target.addDynamicallyLegalOp<tensor::CastOp>([&](tensor::CastOp op) {
      return converter.isLegal(op.getOperand().getType());
    });
    // The following operations and dialects may be introduced by the
    // rewriting rules, and are therefore marked as legal.
    target.addLegalOp<arith::CmpFOp, arith::CmpIOp, arith::ConstantOp,
                      arith::IndexCastOp, linalg::FillOp, linalg::YieldOp,
                      tensor::ExtractOp>();
    target.addLegalDialect<bufferization::BufferizationDialect,
                           LLVM::LLVMDialect, memref::MemRefDialect,
                           scf::SCFDialect>();
    // Populate with rules and apply rewriting rules.
    populateFuncOpTypeConversionPattern(patterns, converter);
    populateCallOpTypeConversionPattern(patterns, converter);
    populateSparseTensorConversionPatterns(converter, patterns);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};

} // namespace

std::unique_ptr<Pass> mlir::createSparsificationPass() {
  return std::make_unique<SparsificationPass>();
}

std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() {
  return std::make_unique<SparseTensorConversionPass>();
}