1 //===- SparseTensorLowering.cpp - Sparse tensor primitives conversion -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Convert sparse tensor primitives to calls into a runtime support library.
10 // Note that this is a current implementation choice to keep the conversion
11 // simple. In principle, these primitives could also be converted to actual
12 // elaborate IR code that implements the primitives on the selected sparse
13 // tensor storage schemes.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
18 #include "mlir/Dialect/MemRef/IR/MemRef.h"
19 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
20 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
21 #include "mlir/Dialect/StandardOps/IR/Ops.h"
22 #include "mlir/Transforms/DialectConversion.h"
23 
24 using namespace mlir;
25 using namespace mlir::sparse_tensor;
26 
27 namespace {
28 
/// Maps the bit width of an overhead storage type to a small integer code.
/// Widths 32, 16, and 8 yield 2, 3, and 4 respectively; every other width
/// (presumably the 64-bit/native case -- confirm against the runtime
/// support library) yields 1.
static unsigned getOverheadTypeEncoding(unsigned width) {
  if (width == 32)
    return 2;
  if (width == 16)
    return 3;
  if (width == 8)
    return 4;
  // Anything not explicitly listed above shares the fallback code.
  return 1;
}
42 
43 /// Returns function reference (first hit also inserts into module).
44 static FlatSymbolRefAttr getFunc(Operation *op, StringRef name, Type result,
45                                  ValueRange operands) {
46   MLIRContext *context = op->getContext();
47   auto module = op->getParentOfType<ModuleOp>();
48   auto func = module.lookupSymbol<FuncOp>(name);
49   if (!func) {
50     OpBuilder moduleBuilder(module.getBodyRegion());
51     moduleBuilder
52         .create<FuncOp>(op->getLoc(), name,
53                         FunctionType::get(context, operands.getTypes(), result))
54         .setPrivate();
55   }
56   return SymbolRefAttr::get(context, name);
57 }
58 
59 /// Sparse conversion rule for returns.
60 class SparseReturnConverter : public OpConversionPattern<ReturnOp> {
61 public:
62   using OpConversionPattern::OpConversionPattern;
63   LogicalResult
64   matchAndRewrite(ReturnOp op, ArrayRef<Value> operands,
65                   ConversionPatternRewriter &rewriter) const override {
66     rewriter.replaceOpWithNewOp<ReturnOp>(op, operands);
67     return success();
68   }
69 };
70 
71 /// Sparse conversion rule for dimension accesses.
72 class SparseTensorToDimSizeConverter
73     : public OpConversionPattern<memref::DimOp> {
74 public:
75   using OpConversionPattern::OpConversionPattern;
76   LogicalResult
77   matchAndRewrite(memref::DimOp op, ArrayRef<Value> operands,
78                   ConversionPatternRewriter &rewriter) const override {
79     if (!operands[0].getType().isa<LLVM::LLVMPointerType>())
80       return failure();
81     Type resType = op.getType();
82     StringRef name = "sparseDimSize";
83     rewriter.replaceOpWithNewOp<CallOp>(
84         op, resType, getFunc(op, name, resType, operands), operands);
85     return success();
86   }
87 };
88 
89 /// Sparse conversion rule for the new operator.
90 class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
91   using OpConversionPattern::OpConversionPattern;
92   LogicalResult
93   matchAndRewrite(NewOp op, ArrayRef<Value> operands,
94                   ConversionPatternRewriter &rewriter) const override {
95     Location loc = op.getLoc();
96     Type resType = op.getType();
97     Type eltType = resType.cast<ShapedType>().getElementType();
98     MLIRContext *context = op->getContext();
99     SmallVector<Value, 5> params;
100     // Sparse encoding.
101     auto enc = getSparseTensorEncoding(resType);
102     if (!enc)
103       return failure();
104     // User pointer.
105     params.push_back(operands[0]);
106     // Sparsity annotations.
107     SmallVector<bool, 4> attrs;
108     unsigned sz = enc.getDimLevelType().size();
109     for (unsigned i = 0; i < sz; i++)
110       attrs.push_back(enc.getDimLevelType()[i] ==
111                       SparseTensorEncodingAttr::DimLevelType::Compressed);
112     auto elts = DenseElementsAttr::get(
113         RankedTensorType::get({sz}, rewriter.getIntegerType(1)), attrs);
114     params.push_back(rewriter.create<ConstantOp>(loc, elts));
115     // Seconary and primary types encoding.
116     unsigned secPtr = getOverheadTypeEncoding(enc.getPointerBitWidth());
117     unsigned secInd = getOverheadTypeEncoding(enc.getIndexBitWidth());
118     unsigned primary;
119     if (eltType.isF64())
120       primary = 1;
121     else if (eltType.isF32())
122       primary = 2;
123     else if (eltType.isInteger(32))
124       primary = 3;
125     else if (eltType.isInteger(16))
126       primary = 4;
127     else if (eltType.isInteger(8))
128       primary = 5;
129     else
130       return failure();
131     params.push_back(
132         rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(secPtr)));
133     params.push_back(
134         rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(secInd)));
135     params.push_back(
136         rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(primary)));
137     // Generate the call to create new tensor.
138     Type ptrType = LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
139     StringRef name = "newSparseTensor";
140     rewriter.replaceOpWithNewOp<CallOp>(
141         op, ptrType, getFunc(op, name, ptrType, params), params);
142     return success();
143   }
144 };
145 
146 /// Sparse conversion rule for pointer accesses.
147 class SparseTensorToPointersConverter
148     : public OpConversionPattern<ToPointersOp> {
149 public:
150   using OpConversionPattern::OpConversionPattern;
151   LogicalResult
152   matchAndRewrite(ToPointersOp op, ArrayRef<Value> operands,
153                   ConversionPatternRewriter &rewriter) const override {
154     Type resType = op.getType();
155     Type eltType = resType.cast<ShapedType>().getElementType();
156     StringRef name;
157     if (eltType.isIndex())
158       name = "sparsePointers";
159     else if (eltType.isInteger(64))
160       name = "sparsePointers64";
161     else if (eltType.isInteger(32))
162       name = "sparsePointers32";
163     else if (eltType.isInteger(16))
164       name = "sparsePointers16";
165     else if (eltType.isInteger(8))
166       name = "sparsePointers8";
167     else
168       return failure();
169     rewriter.replaceOpWithNewOp<CallOp>(
170         op, resType, getFunc(op, name, resType, operands), operands);
171     return success();
172   }
173 };
174 
175 /// Sparse conversion rule for index accesses.
176 class SparseTensorToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
177 public:
178   using OpConversionPattern::OpConversionPattern;
179   LogicalResult
180   matchAndRewrite(ToIndicesOp op, ArrayRef<Value> operands,
181                   ConversionPatternRewriter &rewriter) const override {
182     Type resType = op.getType();
183     Type eltType = resType.cast<ShapedType>().getElementType();
184     StringRef name;
185     if (eltType.isIndex())
186       name = "sparseIndices";
187     else if (eltType.isInteger(64))
188       name = "sparseIndices64";
189     else if (eltType.isInteger(32))
190       name = "sparseIndices32";
191     else if (eltType.isInteger(16))
192       name = "sparseIndices16";
193     else if (eltType.isInteger(8))
194       name = "sparseIndices8";
195     else
196       return failure();
197     rewriter.replaceOpWithNewOp<CallOp>(
198         op, resType, getFunc(op, name, resType, operands), operands);
199     return success();
200   }
201 };
202 
203 /// Sparse conversion rule for value accesses.
204 class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
205 public:
206   using OpConversionPattern::OpConversionPattern;
207   LogicalResult
208   matchAndRewrite(ToValuesOp op, ArrayRef<Value> operands,
209                   ConversionPatternRewriter &rewriter) const override {
210     Type resType = op.getType();
211     Type eltType = resType.cast<ShapedType>().getElementType();
212     StringRef name;
213     if (eltType.isF64())
214       name = "sparseValuesF64";
215     else if (eltType.isF32())
216       name = "sparseValuesF32";
217     else if (eltType.isInteger(32))
218       name = "sparseValuesI32";
219     else if (eltType.isInteger(16))
220       name = "sparseValuesI16";
221     else if (eltType.isInteger(8))
222       name = "sparseValuesI8";
223     else
224       return failure();
225     rewriter.replaceOpWithNewOp<CallOp>(
226         op, resType, getFunc(op, name, resType, operands), operands);
227     return success();
228   }
229 };
230 
231 } // namespace
232 
233 /// Populates the given patterns list with conversion rules required for
234 /// the sparsification of linear algebra operations.
235 void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
236                                                   RewritePatternSet &patterns) {
237   patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter,
238                SparseTensorNewConverter, SparseTensorToPointersConverter,
239                SparseTensorToIndicesConverter, SparseTensorToValuesConverter>(
240       typeConverter, patterns.getContext());
241 }
242