1 //===- SparseTensorLowering.cpp - Sparse tensor primitives conversion -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Convert sparse tensor primitives to calls into a runtime support library.
10 // Note that this is a current implementation choice to keep the conversion
11 // simple. In principle, these primitives could also be converted to actual
12 // elaborate IR code that implements the primitives on the selected sparse
13 // tensor storage schemes.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
18 #include "mlir/Dialect/MemRef/IR/MemRef.h"
19 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
20 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
21 #include "mlir/Dialect/StandardOps/IR/Ops.h"
22 #include "mlir/Dialect/Tensor/IR/Tensor.h"
23 #include "mlir/Transforms/DialectConversion.h"
24 
25 using namespace mlir;
26 using namespace mlir::sparse_tensor;
27 
28 namespace {
29 
/// Maps an overhead (pointer or index) storage bit width onto the integer
/// code that is passed to the runtime support library. Widths 32, 16 and 8
/// have dedicated codes 2, 3 and 4; every other width value falls back to
/// code 1.
static unsigned getOverheadTypeEncoding(unsigned width) {
  if (width == 8)
    return 4;
  if (width == 16)
    return 3;
  if (width == 32)
    return 2;
  return 1;
}
43 
44 /// Returns internal dimension level type encoding.
45 static unsigned
46 getDimLevelTypeEncoding(SparseTensorEncodingAttr::DimLevelType dlt) {
47   switch (dlt) {
48   case SparseTensorEncodingAttr::DimLevelType::Dense:
49     return 0;
50   case SparseTensorEncodingAttr::DimLevelType::Compressed:
51     return 1;
52   case SparseTensorEncodingAttr::DimLevelType::Singleton:
53     return 2;
54   }
55   llvm_unreachable("Unknown SparseTensorEncodingAttr::DimLevelType");
56 }
57 
58 /// Returns integers of given width and values as a constant tensor.
59 /// We cast the static shape into a dynamic shape to ensure that the
60 /// method signature remains uniform accross different tensor dimensions.
61 static Value getTensor(ConversionPatternRewriter &rewriter, unsigned width,
62                        Location loc, ArrayRef<APInt> values) {
63   Type etp = rewriter.getIntegerType(width);
64   unsigned sz = values.size();
65   RankedTensorType tt1 = RankedTensorType::get({sz}, etp);
66   RankedTensorType tt2 = RankedTensorType::get({ShapedType::kDynamicSize}, etp);
67   auto elts =
68       rewriter.create<ConstantOp>(loc, DenseElementsAttr::get(tt1, values));
69   return rewriter.create<tensor::CastOp>(loc, tt2, elts);
70 }
71 
72 /// Returns function reference (first hit also inserts into module).
73 static FlatSymbolRefAttr getFunc(Operation *op, StringRef name, Type result,
74                                  ValueRange operands) {
75   MLIRContext *context = op->getContext();
76   auto module = op->getParentOfType<ModuleOp>();
77   auto func = module.lookupSymbol<FuncOp>(name);
78   if (!func) {
79     OpBuilder moduleBuilder(module.getBodyRegion());
80     moduleBuilder
81         .create<FuncOp>(op->getLoc(), name,
82                         FunctionType::get(context, operands.getTypes(), result))
83         .setPrivate();
84   }
85   return SymbolRefAttr::get(context, name);
86 }
87 
88 /// Sparse conversion rule for returns.
89 class SparseReturnConverter : public OpConversionPattern<ReturnOp> {
90 public:
91   using OpConversionPattern::OpConversionPattern;
92   LogicalResult
93   matchAndRewrite(ReturnOp op, ArrayRef<Value> operands,
94                   ConversionPatternRewriter &rewriter) const override {
95     rewriter.replaceOpWithNewOp<ReturnOp>(op, operands);
96     return success();
97   }
98 };
99 
100 /// Sparse conversion rule for dimension accesses.
101 class SparseTensorToDimSizeConverter
102     : public OpConversionPattern<memref::DimOp> {
103 public:
104   using OpConversionPattern::OpConversionPattern;
105   LogicalResult
106   matchAndRewrite(memref::DimOp op, ArrayRef<Value> operands,
107                   ConversionPatternRewriter &rewriter) const override {
108     if (!operands[0].getType().isa<LLVM::LLVMPointerType>())
109       return failure();
110     Type resType = op.getType();
111     StringRef name = "sparseDimSize";
112     rewriter.replaceOpWithNewOp<CallOp>(
113         op, resType, getFunc(op, name, resType, operands), operands);
114     return success();
115   }
116 };
117 
118 /// Sparse conversion rule for the new operator.
119 class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
120   using OpConversionPattern::OpConversionPattern;
121   LogicalResult
122   matchAndRewrite(NewOp op, ArrayRef<Value> operands,
123                   ConversionPatternRewriter &rewriter) const override {
124     Location loc = op.getLoc();
125     Type resType = op.getType();
126     Type eltType = resType.cast<ShapedType>().getElementType();
127     MLIRContext *context = op->getContext();
128     SmallVector<Value, 5> params;
129     // Sparse encoding.
130     auto enc = getSparseTensorEncoding(resType);
131     if (!enc)
132       return failure();
133     // User pointer.
134     params.push_back(operands[0]);
135     // Sparsity annotations in tensor constant form.
136     SmallVector<APInt, 4> attrs;
137     unsigned sz = enc.getDimLevelType().size();
138     for (unsigned i = 0; i < sz; i++)
139       attrs.push_back(
140           APInt(8, getDimLevelTypeEncoding(enc.getDimLevelType()[i])));
141     params.push_back(getTensor(rewriter, 8, loc, attrs));
142     // Dimension order permutation array. This is the "identity"
143     // permutation by default, or otherwise the "reverse" permutation
144     // of a given ordering, so that indices can be mapped quickly
145     // to the right position.
146     SmallVector<APInt, 4> perm(sz);
147     AffineMap p = enc.getDimOrdering();
148     if (p) {
149       assert(p.isPermutation() && p.getNumResults() == sz);
150       for (unsigned i = 0; i < sz; i++)
151         perm[p.getDimPosition(i)] = APInt(64, i);
152     } else {
153       for (unsigned i = 0; i < sz; i++)
154         perm[i] = APInt(64, i);
155     }
156     params.push_back(getTensor(rewriter, 64, loc, perm));
157     // Secondary and primary types encoding.
158     unsigned secPtr = getOverheadTypeEncoding(enc.getPointerBitWidth());
159     unsigned secInd = getOverheadTypeEncoding(enc.getIndexBitWidth());
160     unsigned primary;
161     if (eltType.isF64())
162       primary = 1;
163     else if (eltType.isF32())
164       primary = 2;
165     else if (eltType.isInteger(32))
166       primary = 3;
167     else if (eltType.isInteger(16))
168       primary = 4;
169     else if (eltType.isInteger(8))
170       primary = 5;
171     else
172       return failure();
173     params.push_back(
174         rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(secPtr)));
175     params.push_back(
176         rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(secInd)));
177     params.push_back(
178         rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(primary)));
179     // Generate the call to create new tensor.
180     Type ptrType = LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
181     StringRef name = "newSparseTensor";
182     rewriter.replaceOpWithNewOp<CallOp>(
183         op, ptrType, getFunc(op, name, ptrType, params), params);
184     return success();
185   }
186 };
187 
188 /// Sparse conversion rule for pointer accesses.
189 class SparseTensorToPointersConverter
190     : public OpConversionPattern<ToPointersOp> {
191 public:
192   using OpConversionPattern::OpConversionPattern;
193   LogicalResult
194   matchAndRewrite(ToPointersOp op, ArrayRef<Value> operands,
195                   ConversionPatternRewriter &rewriter) const override {
196     Type resType = op.getType();
197     Type eltType = resType.cast<ShapedType>().getElementType();
198     StringRef name;
199     if (eltType.isIndex())
200       name = "sparsePointers";
201     else if (eltType.isInteger(64))
202       name = "sparsePointers64";
203     else if (eltType.isInteger(32))
204       name = "sparsePointers32";
205     else if (eltType.isInteger(16))
206       name = "sparsePointers16";
207     else if (eltType.isInteger(8))
208       name = "sparsePointers8";
209     else
210       return failure();
211     rewriter.replaceOpWithNewOp<CallOp>(
212         op, resType, getFunc(op, name, resType, operands), operands);
213     return success();
214   }
215 };
216 
217 /// Sparse conversion rule for index accesses.
218 class SparseTensorToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
219 public:
220   using OpConversionPattern::OpConversionPattern;
221   LogicalResult
222   matchAndRewrite(ToIndicesOp op, ArrayRef<Value> operands,
223                   ConversionPatternRewriter &rewriter) const override {
224     Type resType = op.getType();
225     Type eltType = resType.cast<ShapedType>().getElementType();
226     StringRef name;
227     if (eltType.isIndex())
228       name = "sparseIndices";
229     else if (eltType.isInteger(64))
230       name = "sparseIndices64";
231     else if (eltType.isInteger(32))
232       name = "sparseIndices32";
233     else if (eltType.isInteger(16))
234       name = "sparseIndices16";
235     else if (eltType.isInteger(8))
236       name = "sparseIndices8";
237     else
238       return failure();
239     rewriter.replaceOpWithNewOp<CallOp>(
240         op, resType, getFunc(op, name, resType, operands), operands);
241     return success();
242   }
243 };
244 
245 /// Sparse conversion rule for value accesses.
246 class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
247 public:
248   using OpConversionPattern::OpConversionPattern;
249   LogicalResult
250   matchAndRewrite(ToValuesOp op, ArrayRef<Value> operands,
251                   ConversionPatternRewriter &rewriter) const override {
252     Type resType = op.getType();
253     Type eltType = resType.cast<ShapedType>().getElementType();
254     StringRef name;
255     if (eltType.isF64())
256       name = "sparseValuesF64";
257     else if (eltType.isF32())
258       name = "sparseValuesF32";
259     else if (eltType.isInteger(32))
260       name = "sparseValuesI32";
261     else if (eltType.isInteger(16))
262       name = "sparseValuesI16";
263     else if (eltType.isInteger(8))
264       name = "sparseValuesI8";
265     else
266       return failure();
267     rewriter.replaceOpWithNewOp<CallOp>(
268         op, resType, getFunc(op, name, resType, operands), operands);
269     return success();
270   }
271 };
272 
273 } // namespace
274 
275 /// Populates the given patterns list with conversion rules required for
276 /// the sparsification of linear algebra operations.
277 void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
278                                                   RewritePatternSet &patterns) {
279   patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter,
280                SparseTensorNewConverter, SparseTensorToPointersConverter,
281                SparseTensorToIndicesConverter, SparseTensorToValuesConverter>(
282       typeConverter, patterns.getContext());
283 }
284