1 //===- SparseTensorLowering.cpp - Sparse tensor primitives lowering -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Lower sparse tensor primitives to calls into a runtime support library. 10 // Note that this is a current implementation choice to keep the lowering 11 // simple. In principle, these primitives could also be lowered to actual 12 // elaborate IR code that implements the primitives on the selected sparse 13 // tensor storage schemes. 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "mlir/Dialect/LLVMIR/LLVMTypes.h" 18 #include "mlir/Dialect/MemRef/IR/MemRef.h" 19 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 20 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" 21 #include "mlir/Dialect/StandardOps/IR/Ops.h" 22 #include "mlir/Transforms/DialectConversion.h" 23 24 using namespace mlir; 25 26 namespace { 27 28 /// Returns function reference (first hit also inserts into module). 29 static FlatSymbolRefAttr getFunc(Operation *op, StringRef name, Type result, 30 ValueRange operands) { 31 MLIRContext *context = op->getContext(); 32 auto module = op->getParentOfType<ModuleOp>(); 33 auto func = module.lookupSymbol<FuncOp>(name); 34 if (!func) { 35 OpBuilder moduleBuilder(module.getBodyRegion()); 36 moduleBuilder 37 .create<FuncOp>(op->getLoc(), name, 38 FunctionType::get(context, operands.getTypes(), result)) 39 .setPrivate(); 40 } 41 return SymbolRefAttr::get(context, name); 42 } 43 44 /// Sparse conversion rule to remove opaque pointer cast. 45 class SparseTensorFromPointerConverter 46 : public OpConversionPattern<sparse_tensor::FromPointerOp> { 47 using OpConversionPattern::OpConversionPattern; 48 LogicalResult 49 matchAndRewrite(sparse_tensor::FromPointerOp op, ArrayRef<Value> operands, 50 ConversionPatternRewriter &rewriter) const override { 51 rewriter.replaceOp(op, operands[0]); 52 return success(); 53 } 54 }; 55 56 /// Sparse conversion rule for dimension accesses. 57 class SparseTensorToDimSizeConverter 58 : public OpConversionPattern<memref::DimOp> { 59 public: 60 using OpConversionPattern::OpConversionPattern; 61 LogicalResult 62 matchAndRewrite(memref::DimOp op, ArrayRef<Value> operands, 63 ConversionPatternRewriter &rewriter) const override { 64 if (!operands[0].getType().isa<LLVM::LLVMPointerType>()) 65 return failure(); 66 Type resType = op.getType(); 67 StringRef name = "sparseDimSize"; 68 rewriter.replaceOpWithNewOp<CallOp>( 69 op, resType, getFunc(op, name, resType, operands), operands); 70 return success(); 71 } 72 }; 73 74 /// Sparse conversion rule for pointer accesses. 75 class SparseTensorToPointersConverter 76 : public OpConversionPattern<sparse_tensor::ToPointersOp> { 77 public: 78 using OpConversionPattern::OpConversionPattern; 79 LogicalResult 80 matchAndRewrite(sparse_tensor::ToPointersOp op, ArrayRef<Value> operands, 81 ConversionPatternRewriter &rewriter) const override { 82 Type resType = op.getType(); 83 Type eltType = resType.cast<ShapedType>().getElementType(); 84 StringRef name; 85 if (eltType.isIndex() || eltType.isInteger(64)) 86 name = "sparsePointers64"; 87 else if (eltType.isInteger(32)) 88 name = "sparsePointers32"; 89 else if (eltType.isInteger(16)) 90 name = "sparsePointers16"; 91 else if (eltType.isInteger(8)) 92 name = "sparsePointers8"; 93 else 94 return failure(); 95 rewriter.replaceOpWithNewOp<CallOp>( 96 op, resType, getFunc(op, name, resType, operands), operands); 97 return success(); 98 } 99 }; 100 101 /// Sparse conversion rule for index accesses. 102 class SparseTensorToIndicesConverter 103 : public OpConversionPattern<sparse_tensor::ToIndicesOp> { 104 public: 105 using OpConversionPattern::OpConversionPattern; 106 LogicalResult 107 matchAndRewrite(sparse_tensor::ToIndicesOp op, ArrayRef<Value> operands, 108 ConversionPatternRewriter &rewriter) const override { 109 Type resType = op.getType(); 110 Type eltType = resType.cast<ShapedType>().getElementType(); 111 StringRef name; 112 if (eltType.isIndex() || eltType.isInteger(64)) 113 name = "sparseIndices64"; 114 else if (eltType.isInteger(32)) 115 name = "sparseIndices32"; 116 else if (eltType.isInteger(16)) 117 name = "sparseIndices16"; 118 else if (eltType.isInteger(8)) 119 name = "sparseIndices8"; 120 else 121 return failure(); 122 rewriter.replaceOpWithNewOp<CallOp>( 123 op, resType, getFunc(op, name, resType, operands), operands); 124 return success(); 125 } 126 }; 127 128 /// Sparse conversion rule for value accesses. 129 class SparseTensorToValuesConverter 130 : public OpConversionPattern<sparse_tensor::ToValuesOp> { 131 public: 132 using OpConversionPattern::OpConversionPattern; 133 LogicalResult 134 matchAndRewrite(sparse_tensor::ToValuesOp op, ArrayRef<Value> operands, 135 ConversionPatternRewriter &rewriter) const override { 136 Type resType = op.getType(); 137 Type eltType = resType.cast<ShapedType>().getElementType(); 138 StringRef name; 139 if (eltType.isF64()) 140 name = "sparseValuesF64"; 141 else if (eltType.isF32()) 142 name = "sparseValuesF32"; 143 else if (eltType.isInteger(32)) 144 name = "sparseValuesI32"; 145 else if (eltType.isInteger(16)) 146 name = "sparseValuesI16"; 147 else if (eltType.isInteger(8)) 148 name = "sparseValuesI8"; 149 else 150 return failure(); 151 rewriter.replaceOpWithNewOp<CallOp>( 152 op, resType, getFunc(op, name, resType, operands), operands); 153 return success(); 154 } 155 }; 156 157 } // namespace 158 159 /// Populates the given patterns list with conversion rules required for 160 /// the sparsification of linear algebra operations. 161 void mlir::populateSparseTensorConversionPatterns(RewritePatternSet &patterns) { 162 patterns.add<SparseTensorFromPointerConverter, SparseTensorToDimSizeConverter, 163 SparseTensorToPointersConverter, SparseTensorToIndicesConverter, 164 SparseTensorToValuesConverter>(patterns.getContext()); 165 } 166