1 //===- SparseTensorLowering.cpp - Sparse tensor primitives lowering -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Lower sparse tensor primitives to calls into a runtime support library.
10 // Note that this is a current implementation choice to keep the lowering
11 // simple. In principle, these primitives could also be lowered to actual
12 // elaborate IR code that implements the primitives on the selected sparse
13 // tensor storage schemes.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
18 #include "mlir/Dialect/MemRef/IR/MemRef.h"
19 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
20 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
21 #include "mlir/Dialect/StandardOps/IR/Ops.h"
22 #include "mlir/Transforms/DialectConversion.h"
23 
24 using namespace mlir;
25 
26 namespace {
27 
28 /// Returns function reference (first hit also inserts into module).
29 static FlatSymbolRefAttr getFunc(Operation *op, StringRef name, Type result,
30                                  ValueRange operands) {
31   MLIRContext *context = op->getContext();
32   auto module = op->getParentOfType<ModuleOp>();
33   auto func = module.lookupSymbol<FuncOp>(name);
34   if (!func) {
35     OpBuilder moduleBuilder(module.getBodyRegion());
36     moduleBuilder
37         .create<FuncOp>(op->getLoc(), name,
38                         FunctionType::get(context, operands.getTypes(), result))
39         .setPrivate();
40   }
41   return SymbolRefAttr::get(context, name);
42 }
43 
44 /// Sparse conversion rule to remove opaque pointer cast.
45 class SparseTensorFromPointerConverter
46     : public OpConversionPattern<sparse_tensor::FromPointerOp> {
47   using OpConversionPattern::OpConversionPattern;
48   LogicalResult
49   matchAndRewrite(sparse_tensor::FromPointerOp op, ArrayRef<Value> operands,
50                   ConversionPatternRewriter &rewriter) const override {
51     rewriter.replaceOp(op, operands[0]);
52     return success();
53   }
54 };
55 
56 /// Sparse conversion rule for dimension accesses.
57 class SparseTensorToDimSizeConverter
58     : public OpConversionPattern<memref::DimOp> {
59 public:
60   using OpConversionPattern::OpConversionPattern;
61   LogicalResult
62   matchAndRewrite(memref::DimOp op, ArrayRef<Value> operands,
63                   ConversionPatternRewriter &rewriter) const override {
64     if (!operands[0].getType().isa<LLVM::LLVMPointerType>())
65       return failure();
66     Type resType = op.getType();
67     StringRef name = "sparseDimSize";
68     rewriter.replaceOpWithNewOp<CallOp>(
69         op, resType, getFunc(op, name, resType, operands), operands);
70     return success();
71   }
72 };
73 
74 /// Sparse conversion rule for pointer accesses.
75 class SparseTensorToPointersConverter
76     : public OpConversionPattern<sparse_tensor::ToPointersOp> {
77 public:
78   using OpConversionPattern::OpConversionPattern;
79   LogicalResult
80   matchAndRewrite(sparse_tensor::ToPointersOp op, ArrayRef<Value> operands,
81                   ConversionPatternRewriter &rewriter) const override {
82     Type resType = op.getType();
83     Type eltType = resType.cast<ShapedType>().getElementType();
84     StringRef name;
85     if (eltType.isIndex() || eltType.isInteger(64))
86       name = "sparsePointers64";
87     else if (eltType.isInteger(32))
88       name = "sparsePointers32";
89     else if (eltType.isInteger(16))
90       name = "sparsePointers16";
91     else if (eltType.isInteger(8))
92       name = "sparsePointers8";
93     else
94       return failure();
95     rewriter.replaceOpWithNewOp<CallOp>(
96         op, resType, getFunc(op, name, resType, operands), operands);
97     return success();
98   }
99 };
100 
101 /// Sparse conversion rule for index accesses.
102 class SparseTensorToIndicesConverter
103     : public OpConversionPattern<sparse_tensor::ToIndicesOp> {
104 public:
105   using OpConversionPattern::OpConversionPattern;
106   LogicalResult
107   matchAndRewrite(sparse_tensor::ToIndicesOp op, ArrayRef<Value> operands,
108                   ConversionPatternRewriter &rewriter) const override {
109     Type resType = op.getType();
110     Type eltType = resType.cast<ShapedType>().getElementType();
111     StringRef name;
112     if (eltType.isIndex() || eltType.isInteger(64))
113       name = "sparseIndices64";
114     else if (eltType.isInteger(32))
115       name = "sparseIndices32";
116     else if (eltType.isInteger(16))
117       name = "sparseIndices16";
118     else if (eltType.isInteger(8))
119       name = "sparseIndices8";
120     else
121       return failure();
122     rewriter.replaceOpWithNewOp<CallOp>(
123         op, resType, getFunc(op, name, resType, operands), operands);
124     return success();
125   }
126 };
127 
128 /// Sparse conversion rule for value accesses.
129 class SparseTensorToValuesConverter
130     : public OpConversionPattern<sparse_tensor::ToValuesOp> {
131 public:
132   using OpConversionPattern::OpConversionPattern;
133   LogicalResult
134   matchAndRewrite(sparse_tensor::ToValuesOp op, ArrayRef<Value> operands,
135                   ConversionPatternRewriter &rewriter) const override {
136     Type resType = op.getType();
137     Type eltType = resType.cast<ShapedType>().getElementType();
138     StringRef name;
139     if (eltType.isF64())
140       name = "sparseValuesF64";
141     else if (eltType.isF32())
142       name = "sparseValuesF32";
143     else if (eltType.isInteger(32))
144       name = "sparseValuesI32";
145     else if (eltType.isInteger(16))
146       name = "sparseValuesI16";
147     else if (eltType.isInteger(8))
148       name = "sparseValuesI8";
149     else
150       return failure();
151     rewriter.replaceOpWithNewOp<CallOp>(
152         op, resType, getFunc(op, name, resType, operands), operands);
153     return success();
154   }
155 };
156 
157 } // namespace
158 
159 /// Populates the given patterns list with conversion rules required for
160 /// the sparsification of linear algebra operations.
161 void mlir::populateSparseTensorConversionPatterns(RewritePatternSet &patterns) {
162   patterns.add<SparseTensorFromPointerConverter, SparseTensorToDimSizeConverter,
163                SparseTensorToPointersConverter, SparseTensorToIndicesConverter,
164                SparseTensorToValuesConverter>(patterns.getContext());
165 }
166