1 //===- SparseTensorPasses.cpp - Pass for autogen sparse tensor code -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 10 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 11 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 12 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" 13 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h" 14 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 15 16 using namespace mlir; 17 using namespace mlir::sparse_tensor; 18 19 namespace { 20 21 //===----------------------------------------------------------------------===// 22 // Passes declaration. 23 //===----------------------------------------------------------------------===// 24 25 #define GEN_PASS_CLASSES 26 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc" 27 28 //===----------------------------------------------------------------------===// 29 // Passes implementation. 30 //===----------------------------------------------------------------------===// 31 32 struct SparsificationPass : public SparsificationBase<SparsificationPass> { 33 34 SparsificationPass() = default; 35 SparsificationPass(const SparsificationPass &pass) 36 : SparsificationBase<SparsificationPass>() {} 37 38 /// Returns parallelization strategy given on command line. 39 SparseParallelizationStrategy parallelOption() { 40 switch (parallelization) { 41 default: 42 return SparseParallelizationStrategy::kNone; 43 case 1: 44 return SparseParallelizationStrategy::kDenseOuterLoop; 45 case 2: 46 return SparseParallelizationStrategy::kAnyStorageOuterLoop; 47 case 3: 48 return SparseParallelizationStrategy::kDenseAnyLoop; 49 case 4: 50 return SparseParallelizationStrategy::kAnyStorageAnyLoop; 51 } 52 } 53 54 /// Returns vectorization strategy given on command line. 55 SparseVectorizationStrategy vectorOption() { 56 switch (vectorization) { 57 default: 58 return SparseVectorizationStrategy::kNone; 59 case 1: 60 return SparseVectorizationStrategy::kDenseInnerLoop; 61 case 2: 62 return SparseVectorizationStrategy::kAnyStorageInnerLoop; 63 } 64 } 65 66 void runOnOperation() override { 67 auto *ctx = &getContext(); 68 RewritePatternSet patterns(ctx); 69 // Translate strategy flags to strategy options. 70 SparsificationOptions options(parallelOption(), vectorOption(), 71 vectorLength, enableSIMDIndex32); 72 // Apply rewriting. 73 populateSparsificationPatterns(patterns, options); 74 vector::populateVectorToVectorCanonicalizationPatterns(patterns); 75 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 76 } 77 }; 78 79 class SparseTensorTypeConverter : public TypeConverter { 80 public: 81 SparseTensorTypeConverter() { 82 addConversion([](Type type) { return type; }); 83 addConversion(convertSparseTensorTypes); 84 } 85 // Maps each sparse tensor type to an opaque pointer. 86 static Optional<Type> convertSparseTensorTypes(Type type) { 87 if (getSparseTensorEncoding(type) != nullptr) 88 return LLVM::LLVMPointerType::get(IntegerType::get(type.getContext(), 8)); 89 return llvm::None; 90 } 91 }; 92 93 struct SparseTensorConversionPass 94 : public SparseTensorConversionBase<SparseTensorConversionPass> { 95 void runOnOperation() override { 96 auto *ctx = &getContext(); 97 RewritePatternSet patterns(ctx); 98 SparseTensorTypeConverter converter; 99 ConversionTarget target(*ctx); 100 target.addIllegalOp<NewOp, ConvertOp, ToPointersOp, ToIndicesOp, ToValuesOp, 101 ToTensorOp>(); 102 // All dynamic rules below accept new function, call, return, and dimop 103 // operations as legal output of the rewriting provided that all sparse 104 // tensor types have been fully rewritten. 105 target.addDynamicallyLegalOp<FuncOp>( 106 [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); 107 target.addDynamicallyLegalOp<CallOp>([&](CallOp op) { 108 return converter.isSignatureLegal(op.getCalleeType()); 109 }); 110 target.addDynamicallyLegalOp<ReturnOp>( 111 [&](ReturnOp op) { return converter.isLegal(op.getOperandTypes()); }); 112 target.addDynamicallyLegalOp<tensor::DimOp>([&](tensor::DimOp op) { 113 return converter.isLegal(op.getOperandTypes()); 114 }); 115 // The following operations and dialects may be introduced by the 116 // rewriting rules, and are therefore marked as legal. 117 target.addLegalOp<arith::ConstantOp, ConstantOp, arith::IndexCastOp, 118 tensor::CastOp, tensor::ExtractOp, arith::CmpFOp, 119 arith::CmpIOp>(); 120 target.addLegalDialect<scf::SCFDialect, LLVM::LLVMDialect, 121 memref::MemRefDialect>(); 122 // Populate with rules and apply rewriting rules. 123 populateFuncOpTypeConversionPattern(patterns, converter); 124 populateCallOpTypeConversionPattern(patterns, converter); 125 populateSparseTensorConversionPatterns(converter, patterns); 126 if (failed(applyPartialConversion(getOperation(), target, 127 std::move(patterns)))) 128 signalPassFailure(); 129 } 130 }; 131 132 } // end anonymous namespace 133 134 std::unique_ptr<Pass> mlir::createSparsificationPass() { 135 return std::make_unique<SparsificationPass>(); 136 } 137 138 std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() { 139 return std::make_unique<SparseTensorConversionPass>(); 140 } 141