1 //===- SparseTensorLowering.cpp - Sparse tensor primitives conversion -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Convert sparse tensor primitives to calls into a runtime support library.
10 // Note that this is a current implementation choice to keep the conversion
11 // simple. In principle, these primitives could also be converted to actual
12 // elaborate IR code that implements the primitives on the selected sparse
13 // tensor storage schemes.
14 //
15 //===----------------------------------------------------------------------===//
16 
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/ErrorHandling.h"
24 
25 using namespace mlir;
26 using namespace mlir::sparse_tensor;
27 
28 namespace {
29 
/// Maps the bit width of overhead storage (pointers or indices) to the integer
/// code handed to the runtime support library:
///   8 -> 4, 16 -> 3, 32 -> 2, and every other width (64, 0, ...) -> 1.
static unsigned getOverheadTypeEncoding(unsigned width) {
  if (width == 8)
    return 4;
  if (width == 16)
    return 3;
  if (width == 32)
    return 2;
  // Fallback code, shared by 64 and any width not listed above.
  return 1;
}
43 
44 /// Returns internal dimension level type encoding.
45 static unsigned
46 getDimLevelTypeEncoding(SparseTensorEncodingAttr::DimLevelType dlt) {
47   switch (dlt) {
48   case SparseTensorEncodingAttr::DimLevelType::Dense:
49     return 0;
50   case SparseTensorEncodingAttr::DimLevelType::Compressed:
51     return 1;
52   case SparseTensorEncodingAttr::DimLevelType::Singleton:
53     return 2;
54   }
55 }
56 
57 /// Returns function reference (first hit also inserts into module).
58 static FlatSymbolRefAttr getFunc(Operation *op, StringRef name, Type result,
59                                  ValueRange operands) {
60   MLIRContext *context = op->getContext();
61   auto module = op->getParentOfType<ModuleOp>();
62   auto func = module.lookupSymbol<FuncOp>(name);
63   if (!func) {
64     OpBuilder moduleBuilder(module.getBodyRegion());
65     moduleBuilder
66         .create<FuncOp>(op->getLoc(), name,
67                         FunctionType::get(context, operands.getTypes(), result))
68         .setPrivate();
69   }
70   return SymbolRefAttr::get(context, name);
71 }
72 
73 /// Sparse conversion rule for returns.
74 class SparseReturnConverter : public OpConversionPattern<ReturnOp> {
75 public:
76   using OpConversionPattern::OpConversionPattern;
77   LogicalResult
78   matchAndRewrite(ReturnOp op, ArrayRef<Value> operands,
79                   ConversionPatternRewriter &rewriter) const override {
80     rewriter.replaceOpWithNewOp<ReturnOp>(op, operands);
81     return success();
82   }
83 };
84 
85 /// Sparse conversion rule for dimension accesses.
86 class SparseTensorToDimSizeConverter
87     : public OpConversionPattern<memref::DimOp> {
88 public:
89   using OpConversionPattern::OpConversionPattern;
90   LogicalResult
91   matchAndRewrite(memref::DimOp op, ArrayRef<Value> operands,
92                   ConversionPatternRewriter &rewriter) const override {
93     if (!operands[0].getType().isa<LLVM::LLVMPointerType>())
94       return failure();
95     Type resType = op.getType();
96     StringRef name = "sparseDimSize";
97     rewriter.replaceOpWithNewOp<CallOp>(
98         op, resType, getFunc(op, name, resType, operands), operands);
99     return success();
100   }
101 };
102 
103 /// Sparse conversion rule for the new operator.
104 class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
105   using OpConversionPattern::OpConversionPattern;
106   LogicalResult
107   matchAndRewrite(NewOp op, ArrayRef<Value> operands,
108                   ConversionPatternRewriter &rewriter) const override {
109     Location loc = op.getLoc();
110     Type resType = op.getType();
111     Type eltType = resType.cast<ShapedType>().getElementType();
112     MLIRContext *context = op->getContext();
113     SmallVector<Value, 5> params;
114     // Sparse encoding.
115     auto enc = getSparseTensorEncoding(resType);
116     if (!enc)
117       return failure();
118     // User pointer.
119     params.push_back(operands[0]);
120     // Sparsity annotations in tensor constant form. Note that we cast
121     // the static shape into a dynamic shape to ensure that the method
122     // signature remains uniform accross different tensor dimensions.
123     SmallVector<APInt, 4> attrs;
124     unsigned sz = enc.getDimLevelType().size();
125     for (unsigned i = 0; i < sz; i++)
126       attrs.push_back(
127           APInt(8, getDimLevelTypeEncoding(enc.getDimLevelType()[i])));
128     Type etp = rewriter.getIntegerType(8);
129     RankedTensorType tt1 = RankedTensorType::get({sz}, etp);
130     RankedTensorType tt2 =
131         RankedTensorType::get({ShapedType::kDynamicSize}, etp);
132     auto elts =
133         rewriter.create<ConstantOp>(loc, DenseElementsAttr::get(tt1, attrs));
134     params.push_back(rewriter.create<tensor::CastOp>(loc, tt2, elts));
135     // Seconary and primary types encoding.
136     unsigned secPtr = getOverheadTypeEncoding(enc.getPointerBitWidth());
137     unsigned secInd = getOverheadTypeEncoding(enc.getIndexBitWidth());
138     unsigned primary;
139     if (eltType.isF64())
140       primary = 1;
141     else if (eltType.isF32())
142       primary = 2;
143     else if (eltType.isInteger(32))
144       primary = 3;
145     else if (eltType.isInteger(16))
146       primary = 4;
147     else if (eltType.isInteger(8))
148       primary = 5;
149     else
150       return failure();
151     params.push_back(
152         rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(secPtr)));
153     params.push_back(
154         rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(secInd)));
155     params.push_back(
156         rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(primary)));
157     // Generate the call to create new tensor.
158     Type ptrType = LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
159     StringRef name = "newSparseTensor";
160     rewriter.replaceOpWithNewOp<CallOp>(
161         op, ptrType, getFunc(op, name, ptrType, params), params);
162     return success();
163   }
164 };
165 
166 /// Sparse conversion rule for pointer accesses.
167 class SparseTensorToPointersConverter
168     : public OpConversionPattern<ToPointersOp> {
169 public:
170   using OpConversionPattern::OpConversionPattern;
171   LogicalResult
172   matchAndRewrite(ToPointersOp op, ArrayRef<Value> operands,
173                   ConversionPatternRewriter &rewriter) const override {
174     Type resType = op.getType();
175     Type eltType = resType.cast<ShapedType>().getElementType();
176     StringRef name;
177     if (eltType.isIndex())
178       name = "sparsePointers";
179     else if (eltType.isInteger(64))
180       name = "sparsePointers64";
181     else if (eltType.isInteger(32))
182       name = "sparsePointers32";
183     else if (eltType.isInteger(16))
184       name = "sparsePointers16";
185     else if (eltType.isInteger(8))
186       name = "sparsePointers8";
187     else
188       return failure();
189     rewriter.replaceOpWithNewOp<CallOp>(
190         op, resType, getFunc(op, name, resType, operands), operands);
191     return success();
192   }
193 };
194 
195 /// Sparse conversion rule for index accesses.
196 class SparseTensorToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
197 public:
198   using OpConversionPattern::OpConversionPattern;
199   LogicalResult
200   matchAndRewrite(ToIndicesOp op, ArrayRef<Value> operands,
201                   ConversionPatternRewriter &rewriter) const override {
202     Type resType = op.getType();
203     Type eltType = resType.cast<ShapedType>().getElementType();
204     StringRef name;
205     if (eltType.isIndex())
206       name = "sparseIndices";
207     else if (eltType.isInteger(64))
208       name = "sparseIndices64";
209     else if (eltType.isInteger(32))
210       name = "sparseIndices32";
211     else if (eltType.isInteger(16))
212       name = "sparseIndices16";
213     else if (eltType.isInteger(8))
214       name = "sparseIndices8";
215     else
216       return failure();
217     rewriter.replaceOpWithNewOp<CallOp>(
218         op, resType, getFunc(op, name, resType, operands), operands);
219     return success();
220   }
221 };
222 
223 /// Sparse conversion rule for value accesses.
224 class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
225 public:
226   using OpConversionPattern::OpConversionPattern;
227   LogicalResult
228   matchAndRewrite(ToValuesOp op, ArrayRef<Value> operands,
229                   ConversionPatternRewriter &rewriter) const override {
230     Type resType = op.getType();
231     Type eltType = resType.cast<ShapedType>().getElementType();
232     StringRef name;
233     if (eltType.isF64())
234       name = "sparseValuesF64";
235     else if (eltType.isF32())
236       name = "sparseValuesF32";
237     else if (eltType.isInteger(32))
238       name = "sparseValuesI32";
239     else if (eltType.isInteger(16))
240       name = "sparseValuesI16";
241     else if (eltType.isInteger(8))
242       name = "sparseValuesI8";
243     else
244       return failure();
245     rewriter.replaceOpWithNewOp<CallOp>(
246         op, resType, getFunc(op, name, resType, operands), operands);
247     return success();
248   }
249 };
250 
251 } // namespace
252 
253 /// Populates the given patterns list with conversion rules required for
254 /// the sparsification of linear algebra operations.
255 void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
256                                                   RewritePatternSet &patterns) {
257   patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter,
258                SparseTensorNewConverter, SparseTensorToPointersConverter,
259                SparseTensorToIndicesConverter, SparseTensorToValuesConverter>(
260       typeConverter, patterns.getContext());
261 }
262