1 //===- SparseTensorLowering.cpp - Sparse tensor primitives conversion -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Convert sparse tensor primitives to calls into a runtime support library.
10 // Note that this is a current implementation choice to keep the conversion
11 // simple. In principle, these primitives could also be converted to actual
12 // elaborate IR code that implements the primitives on the selected sparse
13 // tensor storage schemes.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
18 #include "mlir/Dialect/MemRef/IR/MemRef.h"
19 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
20 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
21 #include "mlir/Dialect/StandardOps/IR/Ops.h"
22 #include "mlir/Dialect/Tensor/IR/Tensor.h"
23 #include "mlir/Transforms/DialectConversion.h"
24 
25 using namespace mlir;
26 using namespace mlir::sparse_tensor;
27 
28 namespace {
29 
/// Internal encoding of primary storage. Keep this enum consistent
/// with the equivalent enum in the sparse runtime support library.
enum PrimaryTypeEnum : uint64_t {
  kF64 = 1, // 64-bit floating-point values
  kF32 = 2, // 32-bit floating-point values
  kI64 = 3, // 64-bit integer values
  kI32 = 4, // 32-bit integer values
  kI16 = 5, // 16-bit integer values
  kI8 = 6   // 8-bit integer values
};
40 
/// Returns internal type encoding for overhead storage (pointers and
/// indices) of the given bit width. Keep these values consistent with
/// the sparse runtime support library.
static unsigned getOverheadTypeEncoding(unsigned width) {
  if (width == 32)
    return 2;
  if (width == 16)
    return 3;
  if (width == 8)
    return 4;
  // Any other width (including the common 64-bit case) uses encoding 1.
  return 1;
}
55 
56 /// Returns internal dimension level type encoding. Keep these
57 /// values consistent with the sparse runtime support library.
58 static unsigned
59 getDimLevelTypeEncoding(SparseTensorEncodingAttr::DimLevelType dlt) {
60   switch (dlt) {
61   case SparseTensorEncodingAttr::DimLevelType::Dense:
62     return 0;
63   case SparseTensorEncodingAttr::DimLevelType::Compressed:
64     return 1;
65   case SparseTensorEncodingAttr::DimLevelType::Singleton:
66     return 2;
67   }
68   llvm_unreachable("Unknown SparseTensorEncodingAttr::DimLevelType");
69 }
70 
71 /// Returns integers of given width and values as a constant tensor.
72 /// We cast the static shape into a dynamic shape to ensure that the
73 /// method signature remains uniform accross different tensor dimensions.
74 static Value getTensor(ConversionPatternRewriter &rewriter, unsigned width,
75                        Location loc, ArrayRef<APInt> values) {
76   Type etp = rewriter.getIntegerType(width);
77   unsigned sz = values.size();
78   RankedTensorType tt1 = RankedTensorType::get({sz}, etp);
79   RankedTensorType tt2 = RankedTensorType::get({ShapedType::kDynamicSize}, etp);
80   auto elts =
81       rewriter.create<ConstantOp>(loc, DenseElementsAttr::get(tt1, values));
82   return rewriter.create<tensor::CastOp>(loc, tt2, elts);
83 }
84 
85 /// Returns function reference (first hit also inserts into module).
86 static FlatSymbolRefAttr getFunc(Operation *op, StringRef name, Type result,
87                                  ValueRange operands) {
88   MLIRContext *context = op->getContext();
89   auto module = op->getParentOfType<ModuleOp>();
90   auto func = module.lookupSymbol<FuncOp>(name);
91   if (!func) {
92     OpBuilder moduleBuilder(module.getBodyRegion());
93     moduleBuilder
94         .create<FuncOp>(op->getLoc(), name,
95                         FunctionType::get(context, operands.getTypes(), result))
96         .setPrivate();
97   }
98   return SymbolRefAttr::get(context, name);
99 }
100 
/// Sparse conversion rule for returns.
class SparseReturnConverter : public OpConversionPattern<ReturnOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReturnOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // Simply rebuild the return with the (possibly type-converted) operands.
    rewriter.replaceOpWithNewOp<ReturnOp>(op, operands);
    return success();
  }
};
112 
113 /// Sparse conversion rule for dimension accesses.
114 class SparseTensorToDimSizeConverter
115     : public OpConversionPattern<tensor::DimOp> {
116 public:
117   using OpConversionPattern::OpConversionPattern;
118   LogicalResult
119   matchAndRewrite(tensor::DimOp op, ArrayRef<Value> operands,
120                   ConversionPatternRewriter &rewriter) const override {
121     if (!operands[0].getType().isa<LLVM::LLVMPointerType>())
122       return failure();
123     Type resType = op.getType();
124     StringRef name = "sparseDimSize";
125     rewriter.replaceOpWithNewOp<CallOp>(
126         op, resType, getFunc(op, name, resType, operands), operands);
127     return success();
128   }
129 };
130 
131 /// Sparse conversion rule for the new operator.
132 class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
133   using OpConversionPattern::OpConversionPattern;
134   LogicalResult
135   matchAndRewrite(NewOp op, ArrayRef<Value> operands,
136                   ConversionPatternRewriter &rewriter) const override {
137     Location loc = op.getLoc();
138     Type resType = op.getType();
139     Type eltType = resType.cast<ShapedType>().getElementType();
140     MLIRContext *context = op->getContext();
141     SmallVector<Value, 5> params;
142     // Sparse encoding.
143     auto enc = getSparseTensorEncoding(resType);
144     if (!enc)
145       return failure();
146     // User pointer.
147     params.push_back(operands[0]);
148     // Sparsity annotations in tensor constant form.
149     SmallVector<APInt, 4> attrs;
150     unsigned sz = enc.getDimLevelType().size();
151     for (unsigned i = 0; i < sz; i++)
152       attrs.push_back(
153           APInt(8, getDimLevelTypeEncoding(enc.getDimLevelType()[i])));
154     params.push_back(getTensor(rewriter, 8, loc, attrs));
155     // Dimension order permutation array. This is the "identity"
156     // permutation by default, or otherwise the "reverse" permutation
157     // of a given ordering, so that indices can be mapped quickly
158     // to the right position.
159     SmallVector<APInt, 4> perm(sz);
160     AffineMap p = enc.getDimOrdering();
161     if (p) {
162       assert(p.isPermutation() && p.getNumResults() == sz);
163       for (unsigned i = 0; i < sz; i++)
164         perm[p.getDimPosition(i)] = APInt(64, i);
165     } else {
166       for (unsigned i = 0; i < sz; i++)
167         perm[i] = APInt(64, i);
168     }
169     params.push_back(getTensor(rewriter, 64, loc, perm));
170     // Secondary and primary types encoding.
171     unsigned secPtr = getOverheadTypeEncoding(enc.getPointerBitWidth());
172     unsigned secInd = getOverheadTypeEncoding(enc.getIndexBitWidth());
173     unsigned primary;
174     if (eltType.isF64())
175       primary = kF64;
176     else if (eltType.isF32())
177       primary = kF32;
178     else if (eltType.isInteger(64))
179       primary = kI64;
180     else if (eltType.isInteger(32))
181       primary = kI32;
182     else if (eltType.isInteger(16))
183       primary = kI16;
184     else if (eltType.isInteger(8))
185       primary = kI8;
186     else
187       return failure();
188     params.push_back(
189         rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(secPtr)));
190     params.push_back(
191         rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(secInd)));
192     params.push_back(
193         rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(primary)));
194     // Generate the call to create new tensor.
195     Type ptrType = LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
196     StringRef name = "newSparseTensor";
197     rewriter.replaceOpWithNewOp<CallOp>(
198         op, ptrType, getFunc(op, name, ptrType, params), params);
199     return success();
200   }
201 };
202 
203 /// Sparse conversion rule for pointer accesses.
204 class SparseTensorToPointersConverter
205     : public OpConversionPattern<ToPointersOp> {
206 public:
207   using OpConversionPattern::OpConversionPattern;
208   LogicalResult
209   matchAndRewrite(ToPointersOp op, ArrayRef<Value> operands,
210                   ConversionPatternRewriter &rewriter) const override {
211     Type resType = op.getType();
212     Type eltType = resType.cast<ShapedType>().getElementType();
213     StringRef name;
214     if (eltType.isIndex())
215       name = "sparsePointers";
216     else if (eltType.isInteger(64))
217       name = "sparsePointers64";
218     else if (eltType.isInteger(32))
219       name = "sparsePointers32";
220     else if (eltType.isInteger(16))
221       name = "sparsePointers16";
222     else if (eltType.isInteger(8))
223       name = "sparsePointers8";
224     else
225       return failure();
226     rewriter.replaceOpWithNewOp<CallOp>(
227         op, resType, getFunc(op, name, resType, operands), operands);
228     return success();
229   }
230 };
231 
232 /// Sparse conversion rule for index accesses.
233 class SparseTensorToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
234 public:
235   using OpConversionPattern::OpConversionPattern;
236   LogicalResult
237   matchAndRewrite(ToIndicesOp op, ArrayRef<Value> operands,
238                   ConversionPatternRewriter &rewriter) const override {
239     Type resType = op.getType();
240     Type eltType = resType.cast<ShapedType>().getElementType();
241     StringRef name;
242     if (eltType.isIndex())
243       name = "sparseIndices";
244     else if (eltType.isInteger(64))
245       name = "sparseIndices64";
246     else if (eltType.isInteger(32))
247       name = "sparseIndices32";
248     else if (eltType.isInteger(16))
249       name = "sparseIndices16";
250     else if (eltType.isInteger(8))
251       name = "sparseIndices8";
252     else
253       return failure();
254     rewriter.replaceOpWithNewOp<CallOp>(
255         op, resType, getFunc(op, name, resType, operands), operands);
256     return success();
257   }
258 };
259 
260 /// Sparse conversion rule for value accesses.
261 class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
262 public:
263   using OpConversionPattern::OpConversionPattern;
264   LogicalResult
265   matchAndRewrite(ToValuesOp op, ArrayRef<Value> operands,
266                   ConversionPatternRewriter &rewriter) const override {
267     Type resType = op.getType();
268     Type eltType = resType.cast<ShapedType>().getElementType();
269     StringRef name;
270     if (eltType.isF64())
271       name = "sparseValuesF64";
272     else if (eltType.isF32())
273       name = "sparseValuesF32";
274     else if (eltType.isInteger(64))
275       name = "sparseValuesI64";
276     else if (eltType.isInteger(32))
277       name = "sparseValuesI32";
278     else if (eltType.isInteger(16))
279       name = "sparseValuesI16";
280     else if (eltType.isInteger(8))
281       name = "sparseValuesI8";
282     else
283       return failure();
284     rewriter.replaceOpWithNewOp<CallOp>(
285         op, resType, getFunc(op, name, resType, operands), operands);
286     return success();
287   }
288 };
289 
290 /// Sparse conversion rule for tensor reconstruction.
291 class SparseTensorToTensorConverter : public OpConversionPattern<ToTensorOp> {
292 public:
293   using OpConversionPattern::OpConversionPattern;
294   LogicalResult
295   // Simply fold the operator into the pointer to the sparse storage scheme.
296   matchAndRewrite(ToTensorOp op, ArrayRef<Value> operands,
297                   ConversionPatternRewriter &rewriter) const override {
298     // Check that all arguments of the tensor reconstruction operators are calls
299     // into the support library that query exactly the same opaque pointer.
300     Value ptr;
301     for (Value op : operands) {
302       if (auto call = op.getDefiningOp<CallOp>()) {
303         Value arg = call.getOperand(0);
304         if (!arg.getType().isa<LLVM::LLVMPointerType>())
305           return failure();
306         if (!ptr)
307           ptr = arg;
308         else if (arg != ptr)
309           return failure();
310       }
311     }
312     // If a single opaque pointer is found, perform the folding.
313     if (!ptr)
314       return failure();
315     rewriter.replaceOp(op, ptr);
316     return success();
317   }
318 };
319 
320 } // namespace
321 
322 /// Populates the given patterns list with conversion rules required for
323 /// the sparsification of linear algebra operations.
324 void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
325                                                   RewritePatternSet &patterns) {
326   patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter,
327                SparseTensorNewConverter, SparseTensorToPointersConverter,
328                SparseTensorToIndicesConverter, SparseTensorToValuesConverter,
329                SparseTensorToTensorConverter>(typeConverter,
330                                               patterns.getContext());
331 }
332