1 //===- SparseTensorLowering.cpp - Sparse tensor primitives conversion -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Convert sparse tensor primitives to calls into a runtime support library.
10 // Note that this is a current implementation choice to keep the conversion
11 // simple. In principle, these primitives could also be converted to actual
12 // elaborate IR code that implements the primitives on the selected sparse
13 // tensor storage schemes.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
18 #include "mlir/Dialect/MemRef/IR/MemRef.h"
19 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
20 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
21 #include "mlir/Dialect/StandardOps/IR/Ops.h"
22 #include "mlir/Dialect/Tensor/IR/Tensor.h"
23 #include "mlir/Transforms/DialectConversion.h"
24 
25 using namespace mlir;
26 using namespace mlir::sparse_tensor;
27 
28 namespace {
29 
30 /// Returns internal type encoding for primary storage. Keep these
31 /// values consistent with the sparse runtime support library.
32 static unsigned getPrimaryTypeEncoding(Type tp) {
33   if (tp.isF64())
34     return 1;
35   if (tp.isF32())
36     return 2;
37   if (tp.isInteger(64))
38     return 3;
39   if (tp.isInteger(32))
40     return 4;
41   if (tp.isInteger(16))
42     return 5;
43   if (tp.isInteger(8))
44     return 6;
45   return 0;
46 }
47 
/// Maps the bit width of an overhead storage array (pointers or indices) to
/// the integer code understood by the sparse runtime support library:
///   32 -> 2, 16 -> 3, 8 -> 4; every other width (e.g. 64 or 0) -> 1.
/// These codes must stay in sync with the runtime support library.
static unsigned getOverheadTypeEncoding(unsigned width) {
  if (width == 32)
    return 2;
  if (width == 16)
    return 3;
  if (width == 8)
    return 4;
  return 1;
}
62 
63 /// Returns internal dimension level type encoding. Keep these
64 /// values consistent with the sparse runtime support library.
65 static unsigned
66 getDimLevelTypeEncoding(SparseTensorEncodingAttr::DimLevelType dlt) {
67   switch (dlt) {
68   case SparseTensorEncodingAttr::DimLevelType::Dense:
69     return 0;
70   case SparseTensorEncodingAttr::DimLevelType::Compressed:
71     return 1;
72   case SparseTensorEncodingAttr::DimLevelType::Singleton:
73     return 2;
74   }
75   llvm_unreachable("Unknown SparseTensorEncodingAttr::DimLevelType");
76 }
77 
78 /// Returns integers of given width and values as a constant tensor.
79 /// We cast the static shape into a dynamic shape to ensure that the
80 /// method signature remains uniform accross different tensor dimensions.
81 static Value getTensor(ConversionPatternRewriter &rewriter, unsigned width,
82                        Location loc, ArrayRef<APInt> values) {
83   Type etp = rewriter.getIntegerType(width);
84   unsigned sz = values.size();
85   RankedTensorType tt1 = RankedTensorType::get({sz}, etp);
86   RankedTensorType tt2 = RankedTensorType::get({ShapedType::kDynamicSize}, etp);
87   auto elts =
88       rewriter.create<ConstantOp>(loc, DenseElementsAttr::get(tt1, values));
89   return rewriter.create<tensor::CastOp>(loc, tt2, elts);
90 }
91 
92 /// Returns function reference (first hit also inserts into module).
93 static FlatSymbolRefAttr getFunc(Operation *op, StringRef name, Type result,
94                                  ValueRange operands) {
95   MLIRContext *context = op->getContext();
96   auto module = op->getParentOfType<ModuleOp>();
97   auto func = module.lookupSymbol<FuncOp>(name);
98   if (!func) {
99     OpBuilder moduleBuilder(module.getBodyRegion());
100     moduleBuilder
101         .create<FuncOp>(op->getLoc(), name,
102                         FunctionType::get(context, operands.getTypes(), result))
103         .setPrivate();
104   }
105   return SymbolRefAttr::get(context, name);
106 }
107 
108 /// Sparse conversion rule for returns.
109 class SparseReturnConverter : public OpConversionPattern<ReturnOp> {
110 public:
111   using OpConversionPattern::OpConversionPattern;
112   LogicalResult
113   matchAndRewrite(ReturnOp op, ArrayRef<Value> operands,
114                   ConversionPatternRewriter &rewriter) const override {
115     rewriter.replaceOpWithNewOp<ReturnOp>(op, operands);
116     return success();
117   }
118 };
119 
120 /// Sparse conversion rule for dimension accesses.
121 class SparseTensorToDimSizeConverter
122     : public OpConversionPattern<tensor::DimOp> {
123 public:
124   using OpConversionPattern::OpConversionPattern;
125   LogicalResult
126   matchAndRewrite(tensor::DimOp op, ArrayRef<Value> operands,
127                   ConversionPatternRewriter &rewriter) const override {
128     if (!operands[0].getType().isa<LLVM::LLVMPointerType>())
129       return failure();
130     Type resType = op.getType();
131     StringRef name = "sparseDimSize";
132     rewriter.replaceOpWithNewOp<CallOp>(
133         op, resType, getFunc(op, name, resType, operands), operands);
134     return success();
135   }
136 };
137 
138 /// Sparse conversion rule for the new operator.
139 class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
140   using OpConversionPattern::OpConversionPattern;
141   LogicalResult
142   matchAndRewrite(NewOp op, ArrayRef<Value> operands,
143                   ConversionPatternRewriter &rewriter) const override {
144     Location loc = op.getLoc();
145     Type resType = op.getType();
146     Type eltType = resType.cast<ShapedType>().getElementType();
147     MLIRContext *context = op->getContext();
148     SmallVector<Value, 5> params;
149     // Sparse encoding.
150     auto enc = getSparseTensorEncoding(resType);
151     if (!enc)
152       return failure();
153     // User pointer.
154     params.push_back(operands[0]);
155     // Sparsity annotations in tensor constant form.
156     SmallVector<APInt, 4> attrs;
157     unsigned sz = enc.getDimLevelType().size();
158     for (unsigned i = 0; i < sz; i++)
159       attrs.push_back(
160           APInt(8, getDimLevelTypeEncoding(enc.getDimLevelType()[i])));
161     params.push_back(getTensor(rewriter, 8, loc, attrs));
162     // Dimension order permutation array. This is the "identity"
163     // permutation by default, or otherwise the "reverse" permutation
164     // of a given ordering, so that indices can be mapped quickly
165     // to the right position.
166     SmallVector<APInt, 4> perm(sz);
167     AffineMap p = enc.getDimOrdering();
168     if (p) {
169       assert(p.isPermutation() && p.getNumResults() == sz);
170       for (unsigned i = 0; i < sz; i++)
171         perm[p.getDimPosition(i)] = APInt(64, i);
172     } else {
173       for (unsigned i = 0; i < sz; i++)
174         perm[i] = APInt(64, i);
175     }
176     params.push_back(getTensor(rewriter, 64, loc, perm));
177     // Secondary and primary types encoding.
178     unsigned secPtr = getOverheadTypeEncoding(enc.getPointerBitWidth());
179     unsigned secInd = getOverheadTypeEncoding(enc.getIndexBitWidth());
180     unsigned primary = getPrimaryTypeEncoding(eltType);
181     if (!primary)
182       return failure();
183     params.push_back(
184         rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(secPtr)));
185     params.push_back(
186         rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(secInd)));
187     params.push_back(
188         rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(primary)));
189     // Generate the call to create new tensor.
190     Type ptrType = LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
191     StringRef name = "newSparseTensor";
192     rewriter.replaceOpWithNewOp<CallOp>(
193         op, ptrType, getFunc(op, name, ptrType, params), params);
194     return success();
195   }
196 };
197 
198 /// Sparse conversion rule for the convert operator.
199 class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
200   using OpConversionPattern::OpConversionPattern;
201   LogicalResult
202   matchAndRewrite(ConvertOp op, ArrayRef<Value> operands,
203                   ConversionPatternRewriter &rewriter) const override {
204     // TODO: implement conversions lowering
205     return failure();
206   }
207 };
208 
209 /// Sparse conversion rule for pointer accesses.
210 class SparseTensorToPointersConverter
211     : public OpConversionPattern<ToPointersOp> {
212 public:
213   using OpConversionPattern::OpConversionPattern;
214   LogicalResult
215   matchAndRewrite(ToPointersOp op, ArrayRef<Value> operands,
216                   ConversionPatternRewriter &rewriter) const override {
217     Type resType = op.getType();
218     Type eltType = resType.cast<ShapedType>().getElementType();
219     StringRef name;
220     if (eltType.isIndex())
221       name = "sparsePointers";
222     else if (eltType.isInteger(64))
223       name = "sparsePointers64";
224     else if (eltType.isInteger(32))
225       name = "sparsePointers32";
226     else if (eltType.isInteger(16))
227       name = "sparsePointers16";
228     else if (eltType.isInteger(8))
229       name = "sparsePointers8";
230     else
231       return failure();
232     rewriter.replaceOpWithNewOp<CallOp>(
233         op, resType, getFunc(op, name, resType, operands), operands);
234     return success();
235   }
236 };
237 
238 /// Sparse conversion rule for index accesses.
239 class SparseTensorToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
240 public:
241   using OpConversionPattern::OpConversionPattern;
242   LogicalResult
243   matchAndRewrite(ToIndicesOp op, ArrayRef<Value> operands,
244                   ConversionPatternRewriter &rewriter) const override {
245     Type resType = op.getType();
246     Type eltType = resType.cast<ShapedType>().getElementType();
247     StringRef name;
248     if (eltType.isIndex())
249       name = "sparseIndices";
250     else if (eltType.isInteger(64))
251       name = "sparseIndices64";
252     else if (eltType.isInteger(32))
253       name = "sparseIndices32";
254     else if (eltType.isInteger(16))
255       name = "sparseIndices16";
256     else if (eltType.isInteger(8))
257       name = "sparseIndices8";
258     else
259       return failure();
260     rewriter.replaceOpWithNewOp<CallOp>(
261         op, resType, getFunc(op, name, resType, operands), operands);
262     return success();
263   }
264 };
265 
266 /// Sparse conversion rule for value accesses.
267 class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
268 public:
269   using OpConversionPattern::OpConversionPattern;
270   LogicalResult
271   matchAndRewrite(ToValuesOp op, ArrayRef<Value> operands,
272                   ConversionPatternRewriter &rewriter) const override {
273     Type resType = op.getType();
274     Type eltType = resType.cast<ShapedType>().getElementType();
275     StringRef name;
276     if (eltType.isF64())
277       name = "sparseValuesF64";
278     else if (eltType.isF32())
279       name = "sparseValuesF32";
280     else if (eltType.isInteger(64))
281       name = "sparseValuesI64";
282     else if (eltType.isInteger(32))
283       name = "sparseValuesI32";
284     else if (eltType.isInteger(16))
285       name = "sparseValuesI16";
286     else if (eltType.isInteger(8))
287       name = "sparseValuesI8";
288     else
289       return failure();
290     rewriter.replaceOpWithNewOp<CallOp>(
291         op, resType, getFunc(op, name, resType, operands), operands);
292     return success();
293   }
294 };
295 
296 /// Sparse conversion rule for tensor reconstruction.
297 class SparseTensorToTensorConverter : public OpConversionPattern<ToTensorOp> {
298 public:
299   using OpConversionPattern::OpConversionPattern;
300   LogicalResult
301   // Simply fold the operator into the pointer to the sparse storage scheme.
302   matchAndRewrite(ToTensorOp op, ArrayRef<Value> operands,
303                   ConversionPatternRewriter &rewriter) const override {
304     // Check that all arguments of the tensor reconstruction operators are calls
305     // into the support library that query exactly the same opaque pointer.
306     Value ptr;
307     for (Value op : operands) {
308       if (auto call = op.getDefiningOp<CallOp>()) {
309         Value arg = call.getOperand(0);
310         if (!arg.getType().isa<LLVM::LLVMPointerType>())
311           return failure();
312         if (!ptr)
313           ptr = arg;
314         else if (arg != ptr)
315           return failure();
316       }
317     }
318     // If a single opaque pointer is found, perform the folding.
319     if (!ptr)
320       return failure();
321     rewriter.replaceOp(op, ptr);
322     return success();
323   }
324 };
325 
326 } // namespace
327 
328 /// Populates the given patterns list with conversion rules required for
329 /// the sparsification of linear algebra operations.
330 void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
331                                                   RewritePatternSet &patterns) {
332   patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter,
333                SparseTensorNewConverter, SparseTensorConvertConverter,
334                SparseTensorToPointersConverter, SparseTensorToIndicesConverter,
335                SparseTensorToValuesConverter, SparseTensorToTensorConverter>(
336       typeConverter, patterns.getContext());
337 }
338