1 //===- SparseTensorLowering.cpp - Sparse tensor primitives conversion -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Convert sparse tensor primitives to calls into a runtime support library.
10 // Note that this is a current implementation choice to keep the conversion
11 // simple. In principle, these primitives could also be converted to actual
12 // elaborate IR code that implements the primitives on the selected sparse
13 // tensor storage schemes.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
18 #include "mlir/Dialect/MemRef/IR/MemRef.h"
19 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
20 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
21 #include "mlir/Dialect/StandardOps/IR/Ops.h"
22 #include "mlir/Dialect/Tensor/IR/Tensor.h"
23 #include "mlir/Transforms/DialectConversion.h"
24 
25 using namespace mlir;
26 using namespace mlir::sparse_tensor;
27 
28 namespace {
29 
/// Maps the bit width used for overhead storage (pointers or indices) to the
/// internal type encoding passed to the runtime library:
///   32 -> 2, 16 -> 3, 8 -> 4, and every other width (including 64) -> 1.
static unsigned getOverheadTypeEncoding(unsigned width) {
  if (width == 32)
    return 2;
  if (width == 16)
    return 3;
  if (width == 8)
    return 4;
  // Fallback encoding for all remaining widths.
  return 1;
}
43 
44 /// Returns internal dimension level type encoding.
45 static unsigned
46 getDimLevelTypeEncoding(SparseTensorEncodingAttr::DimLevelType dlt) {
47   switch (dlt) {
48   case SparseTensorEncodingAttr::DimLevelType::Dense:
49     return 0;
50   case SparseTensorEncodingAttr::DimLevelType::Compressed:
51     return 1;
52   case SparseTensorEncodingAttr::DimLevelType::Singleton:
53     return 2;
54   }
55 }
56 
57 /// Returns integers of given width and values as a constant tensor.
58 /// We cast the static shape into a dynamic shape to ensure that the
59 /// method signature remains uniform accross different tensor dimensions.
60 static Value getTensor(ConversionPatternRewriter &rewriter, unsigned width,
61                        Location loc, ArrayRef<APInt> values) {
62   Type etp = rewriter.getIntegerType(width);
63   unsigned sz = values.size();
64   RankedTensorType tt1 = RankedTensorType::get({sz}, etp);
65   RankedTensorType tt2 = RankedTensorType::get({ShapedType::kDynamicSize}, etp);
66   auto elts =
67       rewriter.create<ConstantOp>(loc, DenseElementsAttr::get(tt1, values));
68   return rewriter.create<tensor::CastOp>(loc, tt2, elts);
69 }
70 
71 /// Returns function reference (first hit also inserts into module).
72 static FlatSymbolRefAttr getFunc(Operation *op, StringRef name, Type result,
73                                  ValueRange operands) {
74   MLIRContext *context = op->getContext();
75   auto module = op->getParentOfType<ModuleOp>();
76   auto func = module.lookupSymbol<FuncOp>(name);
77   if (!func) {
78     OpBuilder moduleBuilder(module.getBodyRegion());
79     moduleBuilder
80         .create<FuncOp>(op->getLoc(), name,
81                         FunctionType::get(context, operands.getTypes(), result))
82         .setPrivate();
83   }
84   return SymbolRefAttr::get(context, name);
85 }
86 
87 /// Sparse conversion rule for returns.
88 class SparseReturnConverter : public OpConversionPattern<ReturnOp> {
89 public:
90   using OpConversionPattern::OpConversionPattern;
91   LogicalResult
92   matchAndRewrite(ReturnOp op, ArrayRef<Value> operands,
93                   ConversionPatternRewriter &rewriter) const override {
94     rewriter.replaceOpWithNewOp<ReturnOp>(op, operands);
95     return success();
96   }
97 };
98 
99 /// Sparse conversion rule for dimension accesses.
100 class SparseTensorToDimSizeConverter
101     : public OpConversionPattern<memref::DimOp> {
102 public:
103   using OpConversionPattern::OpConversionPattern;
104   LogicalResult
105   matchAndRewrite(memref::DimOp op, ArrayRef<Value> operands,
106                   ConversionPatternRewriter &rewriter) const override {
107     if (!operands[0].getType().isa<LLVM::LLVMPointerType>())
108       return failure();
109     Type resType = op.getType();
110     StringRef name = "sparseDimSize";
111     rewriter.replaceOpWithNewOp<CallOp>(
112         op, resType, getFunc(op, name, resType, operands), operands);
113     return success();
114   }
115 };
116 
117 /// Sparse conversion rule for the new operator.
118 class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
119   using OpConversionPattern::OpConversionPattern;
120   LogicalResult
121   matchAndRewrite(NewOp op, ArrayRef<Value> operands,
122                   ConversionPatternRewriter &rewriter) const override {
123     Location loc = op.getLoc();
124     Type resType = op.getType();
125     Type eltType = resType.cast<ShapedType>().getElementType();
126     MLIRContext *context = op->getContext();
127     SmallVector<Value, 5> params;
128     // Sparse encoding.
129     auto enc = getSparseTensorEncoding(resType);
130     if (!enc)
131       return failure();
132     // User pointer.
133     params.push_back(operands[0]);
134     // Sparsity annotations in tensor constant form.
135     SmallVector<APInt, 4> attrs;
136     unsigned sz = enc.getDimLevelType().size();
137     for (unsigned i = 0; i < sz; i++)
138       attrs.push_back(
139           APInt(8, getDimLevelTypeEncoding(enc.getDimLevelType()[i])));
140     params.push_back(getTensor(rewriter, 8, loc, attrs));
141     // Dimension order permutation array. This is the "identity"
142     // permutation by default, or otherwise the "reverse" permutation
143     // of a given ordering, so that indices can be mapped quickly
144     // to the right position.
145     SmallVector<APInt, 4> perm(sz);
146     AffineMap p = enc.getDimOrdering();
147     if (p) {
148       assert(p.isPermutation() && p.getNumResults() == sz);
149       for (unsigned i = 0; i < sz; i++)
150         perm[p.getDimPosition(i)] = APInt(64, i);
151     } else {
152       for (unsigned i = 0; i < sz; i++)
153         perm[i] = APInt(64, i);
154     }
155     params.push_back(getTensor(rewriter, 64, loc, perm));
156     // Secondary and primary types encoding.
157     unsigned secPtr = getOverheadTypeEncoding(enc.getPointerBitWidth());
158     unsigned secInd = getOverheadTypeEncoding(enc.getIndexBitWidth());
159     unsigned primary;
160     if (eltType.isF64())
161       primary = 1;
162     else if (eltType.isF32())
163       primary = 2;
164     else if (eltType.isInteger(32))
165       primary = 3;
166     else if (eltType.isInteger(16))
167       primary = 4;
168     else if (eltType.isInteger(8))
169       primary = 5;
170     else
171       return failure();
172     params.push_back(
173         rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(secPtr)));
174     params.push_back(
175         rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(secInd)));
176     params.push_back(
177         rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(primary)));
178     // Generate the call to create new tensor.
179     Type ptrType = LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
180     StringRef name = "newSparseTensor";
181     rewriter.replaceOpWithNewOp<CallOp>(
182         op, ptrType, getFunc(op, name, ptrType, params), params);
183     return success();
184   }
185 };
186 
187 /// Sparse conversion rule for pointer accesses.
188 class SparseTensorToPointersConverter
189     : public OpConversionPattern<ToPointersOp> {
190 public:
191   using OpConversionPattern::OpConversionPattern;
192   LogicalResult
193   matchAndRewrite(ToPointersOp op, ArrayRef<Value> operands,
194                   ConversionPatternRewriter &rewriter) const override {
195     Type resType = op.getType();
196     Type eltType = resType.cast<ShapedType>().getElementType();
197     StringRef name;
198     if (eltType.isIndex())
199       name = "sparsePointers";
200     else if (eltType.isInteger(64))
201       name = "sparsePointers64";
202     else if (eltType.isInteger(32))
203       name = "sparsePointers32";
204     else if (eltType.isInteger(16))
205       name = "sparsePointers16";
206     else if (eltType.isInteger(8))
207       name = "sparsePointers8";
208     else
209       return failure();
210     rewriter.replaceOpWithNewOp<CallOp>(
211         op, resType, getFunc(op, name, resType, operands), operands);
212     return success();
213   }
214 };
215 
216 /// Sparse conversion rule for index accesses.
217 class SparseTensorToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
218 public:
219   using OpConversionPattern::OpConversionPattern;
220   LogicalResult
221   matchAndRewrite(ToIndicesOp op, ArrayRef<Value> operands,
222                   ConversionPatternRewriter &rewriter) const override {
223     Type resType = op.getType();
224     Type eltType = resType.cast<ShapedType>().getElementType();
225     StringRef name;
226     if (eltType.isIndex())
227       name = "sparseIndices";
228     else if (eltType.isInteger(64))
229       name = "sparseIndices64";
230     else if (eltType.isInteger(32))
231       name = "sparseIndices32";
232     else if (eltType.isInteger(16))
233       name = "sparseIndices16";
234     else if (eltType.isInteger(8))
235       name = "sparseIndices8";
236     else
237       return failure();
238     rewriter.replaceOpWithNewOp<CallOp>(
239         op, resType, getFunc(op, name, resType, operands), operands);
240     return success();
241   }
242 };
243 
244 /// Sparse conversion rule for value accesses.
245 class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
246 public:
247   using OpConversionPattern::OpConversionPattern;
248   LogicalResult
249   matchAndRewrite(ToValuesOp op, ArrayRef<Value> operands,
250                   ConversionPatternRewriter &rewriter) const override {
251     Type resType = op.getType();
252     Type eltType = resType.cast<ShapedType>().getElementType();
253     StringRef name;
254     if (eltType.isF64())
255       name = "sparseValuesF64";
256     else if (eltType.isF32())
257       name = "sparseValuesF32";
258     else if (eltType.isInteger(32))
259       name = "sparseValuesI32";
260     else if (eltType.isInteger(16))
261       name = "sparseValuesI16";
262     else if (eltType.isInteger(8))
263       name = "sparseValuesI8";
264     else
265       return failure();
266     rewriter.replaceOpWithNewOp<CallOp>(
267         op, resType, getFunc(op, name, resType, operands), operands);
268     return success();
269   }
270 };
271 
272 } // namespace
273 
274 /// Populates the given patterns list with conversion rules required for
275 /// the sparsification of linear algebra operations.
276 void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
277                                                   RewritePatternSet &patterns) {
278   patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter,
279                SparseTensorNewConverter, SparseTensorToPointersConverter,
280                SparseTensorToIndicesConverter, SparseTensorToValuesConverter>(
281       typeConverter, patterns.getContext());
282 }
283