//===- SparseTensorLowering.cpp - Sparse tensor primitives conversion -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Convert sparse tensor primitives to calls into a runtime support library.
// Note that this is a current implementation choice to keep the conversion
// simple. In principle, these primitives could also be converted to actual
// elaborate IR code that implements the primitives on the selected sparse
// tensor storage schemes.
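//
// For example, with this conversion in place, a sparse tensor value (e.g. a
// tensor<?x?xf64, #SparseMatrix> with an annotated sparse layout) is
// represented by a single opaque pointer (!llvm.ptr<i8>) into the runtime's
// storage, and primitives such as sparse_tensor.pointers simply become calls
// into the support library that operate on that pointer.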
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

/// Returns internal type encoding for primary storage. Keep these
/// values consistent with the sparse runtime support library.
static unsigned getPrimaryTypeEncoding(Type tp) {
  if (tp.isF64())
    return 1;
  if (tp.isF32())
    return 2;
  if (tp.isInteger(64))
    return 3;
  if (tp.isInteger(32))
    return 4;
  if (tp.isInteger(16))
    return 5;
  if (tp.isInteger(8))
    return 6;
  return 0;
}

/// Returns internal type encoding for overhead storage. Keep these
/// values consistent with the sparse runtime support library.
static unsigned getOverheadTypeEncoding(unsigned width) {
  switch (width) {
  default:
    return 1;
  case 32:
    return 2;
  case 16:
    return 3;
  case 8:
    return 4;
  }
}

/// Returns internal dimension level type encoding. Keep these
/// values consistent with the sparse runtime support library.
static unsigned
getDimLevelTypeEncoding(SparseTensorEncodingAttr::DimLevelType dlt) {
  switch (dlt) {
  case SparseTensorEncodingAttr::DimLevelType::Dense:
    return 0;
  case SparseTensorEncodingAttr::DimLevelType::Compressed:
    return 1;
  case SparseTensorEncodingAttr::DimLevelType::Singleton:
    return 2;
  }
  llvm_unreachable("Unknown SparseTensorEncodingAttr::DimLevelType");
}

/// Returns integers of given width and values as a constant tensor.
/// We cast the static shape into a dynamic shape to ensure that the
/// method signature remains uniform across different tensor dimensions.
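/// Illustrative output for width = 8 and values = {0, 1} (SSA names
/// are made up):
///   %0 = constant dense<[0, 1]> : tensor<2xi8>
///   %1 = tensor.cast %0 : tensor<2xi8> to tensor<?xi8>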
static Value getTensor(ConversionPatternRewriter &rewriter, unsigned width,
                       Location loc, ArrayRef<APInt> values) {
  Type etp = rewriter.getIntegerType(width);
  unsigned sz = values.size();
  RankedTensorType tt1 = RankedTensorType::get({sz}, etp);
  RankedTensorType tt2 = RankedTensorType::get({ShapedType::kDynamicSize}, etp);
  auto elts =
      rewriter.create<ConstantOp>(loc, DenseElementsAttr::get(tt1, values));
  return rewriter.create<tensor::CastOp>(loc, tt2, elts);
}

/// Returns function reference (first hit also inserts into module).
static FlatSymbolRefAttr getFunc(Operation *op, StringRef name, Type result,
                                 ValueRange operands) {
  MLIRContext *context = op->getContext();
  auto module = op->getParentOfType<ModuleOp>();
  auto func = module.lookupSymbol<FuncOp>(name);
  if (!func) {
    OpBuilder moduleBuilder(module.getBodyRegion());
    moduleBuilder
        .create<FuncOp>(op->getLoc(), name,
                        FunctionType::get(context, operands.getTypes(), result))
        .setPrivate();
  }
  return SymbolRefAttr::get(context, name);
}

/// Generates a call into the "swiss army knife" method of the sparse runtime
/// support library for materializing sparse tensors into the computation. The
/// method returns the call value and assigns the permutation to 'perm'.
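/// As can be seen from the call sites below, the 'action' operand encodes
/// the request made to the support library (keep the values consistent
/// with that library):
///   0 : materialize a new sparse tensor from the source that 'ptr'
///       references (used by the new operator)
///   1 : construct a sparse tensor from the coordinate scheme that 'ptr'
///       references
///   2 : construct an empty coordinate scheme
///   3 : convert the sparse tensor that 'ptr' references into a coordinate
///       scheme in the destination dimension order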
static Value genNewCall(ConversionPatternRewriter &rewriter, Operation *op,
                        SparseTensorEncodingAttr &enc, uint32_t action,
                        Value &perm, Value ptr = Value()) {
  Location loc = op->getLoc();
  ShapedType resType = op->getResult(0).getType().cast<ShapedType>();
  SmallVector<Value, 8> params;
  // Sparsity annotations in tensor constant form.
  SmallVector<APInt, 4> attrs;
  unsigned sz = enc.getDimLevelType().size();
  for (unsigned i = 0; i < sz; i++)
    attrs.push_back(
        APInt(8, getDimLevelTypeEncoding(enc.getDimLevelType()[i])));
  params.push_back(getTensor(rewriter, 8, loc, attrs));
  // Dimension sizes array of the enveloping *dense* tensor. Useful for either
  // verification of external data, or for construction of internal data.
  auto shape = resType.getShape();
  SmallVector<APInt, 4> sizes;
  for (unsigned i = 0; i < sz; i++) {
    uint64_t s = shape[i] == ShapedType::kDynamicSize ? 0 : shape[i];
    sizes.push_back(APInt(64, s));
  }
  params.push_back(getTensor(rewriter, 64, loc, sizes));
  // Dimension order permutation array. This is the "identity" permutation by
  // default, or otherwise the "reverse" permutation of a given ordering, so
  // that indices can be mapped quickly to the right position.
  SmallVector<APInt, 4> rev(sz);
  if (AffineMap p = enc.getDimOrdering()) {
    for (unsigned i = 0; i < sz; i++)
      rev[p.getDimPosition(i)] = APInt(64, i);
  } else {
    for (unsigned i = 0; i < sz; i++)
      rev[i] = APInt(64, i);
  }
  perm = getTensor(rewriter, 64, loc, rev);
  params.push_back(perm);
  // Secondary and primary types encoding.
  unsigned secPtr = getOverheadTypeEncoding(enc.getPointerBitWidth());
  unsigned secInd = getOverheadTypeEncoding(enc.getIndexBitWidth());
  unsigned primary = getPrimaryTypeEncoding(resType.getElementType());
  assert(primary);
  params.push_back(
      rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(secPtr)));
  params.push_back(
      rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(secInd)));
  params.push_back(
      rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(primary)));
  // User action and pointer.
  Type pTp = LLVM::LLVMPointerType::get(IntegerType::get(op->getContext(), 8));
  if (!ptr)
    ptr = rewriter.create<LLVM::NullOp>(loc, pTp);
  params.push_back(
      rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(action)));
  params.push_back(ptr);
  // Generate the call to create new tensor.
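  // The generated call has the following illustrative shape (SSA names
  // are made up):
  //   %t = call @newSparseTensor(%attrs, %sizes, %perm, %secPtr, %secInd,
  //                              %primary, %action, %ptr) -> !llvm.ptr<i8>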
  StringRef name = "newSparseTensor";
  auto call =
      rewriter.create<CallOp>(loc, pTp, getFunc(op, name, pTp, params), params);
  return call.getResult(0);
}

/// Generates a call that adds one element to a coordinate scheme.
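/// Per visited element, the generated IR has the following illustrative
/// shape (for an f64 element type; SSA names are made up):
///   %v = tensor.extract %tensor[%i0, .., %ik]
///   memref.store %i0, %ind[%c0]
///   ..
///   memref.store %ik, %ind[%ck]
///   call @addEltF64(%ptr, %v, %ind, %perm)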
static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
                          Value ptr, Value tensor, Value ind, Value perm,
                          ValueRange ivs) {
  Location loc = op->getLoc();
  StringRef name;
  Type eltType = tensor.getType().cast<ShapedType>().getElementType();
  if (eltType.isF64())
    name = "addEltF64";
  else if (eltType.isF32())
    name = "addEltF32";
  else if (eltType.isInteger(64))
    name = "addEltI64";
  else if (eltType.isInteger(32))
    name = "addEltI32";
  else if (eltType.isInteger(16))
    name = "addEltI16";
  else if (eltType.isInteger(8))
    name = "addEltI8";
  else
    llvm_unreachable("Unknown element type");
  Value val = rewriter.create<tensor::ExtractOp>(loc, tensor, ivs);
  // TODO: guard this with the "if val != 0" test from the dense => sparse
  // conversion sketch, so that zero values are never added?
  unsigned i = 0;
  for (auto iv : ivs) {
    Value idx = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(i++));
    rewriter.create<memref::StoreOp>(loc, iv, ind, idx);
  }
  SmallVector<Value, 8> params;
  params.push_back(ptr);
  params.push_back(val);
  params.push_back(ind);
  params.push_back(perm);
  Type pTp = LLVM::LLVMPointerType::get(IntegerType::get(op->getContext(), 8));
  rewriter.create<CallOp>(loc, pTp, getFunc(op, name, pTp, params), params);
}

//===----------------------------------------------------------------------===//
// Conversion rules.
//===----------------------------------------------------------------------===//

/// Sparse conversion rule for returns.
class SparseReturnConverter : public OpConversionPattern<ReturnOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReturnOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<ReturnOp>(op, operands);
    return success();
  }
};

/// Sparse conversion rule for dimension accesses.
class SparseTensorToDimSizeConverter
    : public OpConversionPattern<tensor::DimOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::DimOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    auto enc = getSparseTensorEncoding(op.source().getType());
    if (!enc)
      return failure();
    // Permute the dim index.
    Optional<int64_t> index = op.getConstantIndex();
    if (!index.hasValue())
      return failure();
    int64_t idx = index.getValue();
    if (AffineMap p = enc.getDimOrdering())
      idx = p.getPermutedPosition(idx);
    // Generate the call.
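    // Illustrative shape of the rewritten op (SSA names are made up):
    //   %d = call @sparseDimSize(%ptr, %idx) : (!llvm.ptr<i8>, index) -> index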
    StringRef name = "sparseDimSize";
    SmallVector<Value, 2> params;
    params.push_back(operands[0]);
    params.push_back(
        rewriter.create<ConstantOp>(op.getLoc(), rewriter.getIndexAttr(idx)));
    rewriter.replaceOpWithNewOp<CallOp>(
        op, resType, getFunc(op, name, resType, params), params);
    return success();
  }
};

/// Sparse conversion rule for the new operator.
class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NewOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    auto enc = getSparseTensorEncoding(resType);
    if (!enc)
      return failure();
    Value perm;
    rewriter.replaceOp(op, genNewCall(rewriter, op, enc, 0, perm, operands[0]));
    return success();
  }
};

/// Sparse conversion rule for the convert operator.
class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ConvertOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    auto encDst = getSparseTensorEncoding(resType);
    auto encSrc = getSparseTensorEncoding(op.source().getType());
    if (encDst && encSrc) {
      // This is a sparse => sparse conversion, which is handled as follows:
      //   t = src->asCOO();         ; src to COO in dst order
      //   dst = newSparseTensor(t)
      // Using the coordinate scheme as an intermediate does not always
      // yield the fastest conversion but avoids the need for a full
      // O(N^2) conversion matrix.
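      // Generated (illustrative):
      //   %coo = call @newSparseTensor(.., action = 3, %src)
      //   %dst = call @newSparseTensor(.., action = 1, %coo)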
      Value perm;
      Value coo = genNewCall(rewriter, op, encDst, 3, perm, operands[0]);
      rewriter.replaceOp(op, genNewCall(rewriter, op, encDst, 1, perm, coo));
      return success();
    }
    if (!encDst || encSrc) {
      // TODO: sparse => dense
      return failure();
    }
    // This is a dense => sparse conversion, which is handled as follows:
    //   t = newSparseCOO()
    //   for i1 in dim1
    //    ..
    //     for ik in dimk
    //       val = a[i1,..,ik]
    //       if val != 0
    //         t->add(val, [i1,..,ik], [p1,..,pk])
    //   s = newSparseTensor(t)
    // Note that the dense tensor traversal code is actually implemented
    // using MLIR IR to avoid having to expose too much low-level
    // memref traversal detail to the runtime support library.
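    // Generated (illustrative), for a 2-dimensional tensor:
    //   %ptr = call @newSparseTensor(.., action = 2, ..)
    //   scf.for %i0 = %c0 to %d0 step %c1 {
    //     scf.for %i1 = %c0 to %d1 step %c1 {
    //       .. tensor.extract, memref.store, call @addEltXX ..
    //     }
    //   }
    //   %dst = call @newSparseTensor(.., action = 1, %ptr)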
    Location loc = op->getLoc();
    ShapedType shape = resType.cast<ShapedType>();
    auto memTp =
        MemRefType::get({ShapedType::kDynamicSize}, rewriter.getIndexType());
    Value perm;
    Value ptr = genNewCall(rewriter, op, encDst, 2, perm);
    Value tensor = operands[0];
    Value arg = rewriter.create<ConstantOp>(
        loc, rewriter.getIndexAttr(shape.getRank()));
    Value ind = rewriter.create<memref::AllocaOp>(loc, memTp, ValueRange{arg});
    SmallVector<Value> lo;
    SmallVector<Value> hi;
    SmallVector<Value> st;
    Value zero = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(0));
    Value one = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(1));
    for (unsigned i = 0, rank = shape.getRank(); i < rank; i++) {
      lo.push_back(zero);
      hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, tensor, i));
      st.push_back(one);
    }
    scf::buildLoopNest(rewriter, op.getLoc(), lo, hi, st, {},
                       [&](OpBuilder &builder, Location loc, ValueRange ivs,
                           ValueRange args) -> scf::ValueVector {
                         genAddEltCall(rewriter, op, ptr, tensor, ind, perm,
                                       ivs);
                         return {};
                       });
    rewriter.replaceOp(op, genNewCall(rewriter, op, encDst, 1, perm, ptr));
    return success();
  }
};

/// Sparse conversion rule for pointer accesses.
class SparseTensorToPointersConverter
    : public OpConversionPattern<ToPointersOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToPointersOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type eltType = resType.cast<ShapedType>().getElementType();
    StringRef name;
    if (eltType.isIndex())
      name = "sparsePointers";
    else if (eltType.isInteger(64))
      name = "sparsePointers64";
    else if (eltType.isInteger(32))
      name = "sparsePointers32";
    else if (eltType.isInteger(16))
      name = "sparsePointers16";
    else if (eltType.isInteger(8))
      name = "sparsePointers8";
    else
      return failure();
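    // E.g., for a memref<?xi32> result (illustrative; SSA names are made up):
    //   %m = call @sparsePointers32(%ptr, %d)
    //          : (!llvm.ptr<i8>, index) -> memref<?xi32>
    // The indices and values converters below follow the same pattern.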
    rewriter.replaceOpWithNewOp<CallOp>(
        op, resType, getFunc(op, name, resType, operands), operands);
    return success();
  }
};

/// Sparse conversion rule for index accesses.
class SparseTensorToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToIndicesOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type eltType = resType.cast<ShapedType>().getElementType();
    StringRef name;
    if (eltType.isIndex())
      name = "sparseIndices";
    else if (eltType.isInteger(64))
      name = "sparseIndices64";
    else if (eltType.isInteger(32))
      name = "sparseIndices32";
    else if (eltType.isInteger(16))
      name = "sparseIndices16";
    else if (eltType.isInteger(8))
      name = "sparseIndices8";
    else
      return failure();
    rewriter.replaceOpWithNewOp<CallOp>(
        op, resType, getFunc(op, name, resType, operands), operands);
    return success();
  }
};

/// Sparse conversion rule for value accesses.
class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToValuesOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type eltType = resType.cast<ShapedType>().getElementType();
    StringRef name;
    if (eltType.isF64())
      name = "sparseValuesF64";
    else if (eltType.isF32())
      name = "sparseValuesF32";
    else if (eltType.isInteger(64))
      name = "sparseValuesI64";
    else if (eltType.isInteger(32))
      name = "sparseValuesI32";
    else if (eltType.isInteger(16))
      name = "sparseValuesI16";
    else if (eltType.isInteger(8))
      name = "sparseValuesI8";
    else
      return failure();
    rewriter.replaceOpWithNewOp<CallOp>(
        op, resType, getFunc(op, name, resType, operands), operands);
    return success();
  }
};

/// Sparse conversion rule for tensor reconstruction.
class SparseTensorToTensorConverter : public OpConversionPattern<ToTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  // Simply fold the operator into the pointer to the sparse storage scheme.
  LogicalResult
  matchAndRewrite(ToTensorOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // Check that all arguments of the tensor reconstruction operator are calls
    // into the support library that query exactly the same opaque pointer.
    Value ptr;
    for (Value operand : operands) {
      if (auto call = operand.getDefiningOp<CallOp>()) {
        Value arg = call.getOperand(0);
        if (!arg.getType().isa<LLVM::LLVMPointerType>())
          return failure();
        if (!ptr)
          ptr = arg;
        else if (arg != ptr)
          return failure();
      }
    }
    // If a single opaque pointer is found, perform the folding.
    if (!ptr)
      return failure();
    rewriter.replaceOp(op, ptr);
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with conversion rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
                                                  RewritePatternSet &patterns) {
  patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter,
               SparseTensorNewConverter, SparseTensorConvertConverter,
               SparseTensorToPointersConverter, SparseTensorToIndicesConverter,
               SparseTensorToValuesConverter, SparseTensorToTensorConverter>(
      typeConverter, patterns.getContext());
}
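
// A minimal usage sketch (hypothetical driver code; assumes a type converter
// that maps annotated sparse tensor types to the opaque !llvm.ptr<i8>):
//
//   RewritePatternSet patterns(context);
//   populateSparseTensorConversionPatterns(typeConverter, patterns);
//   if (failed(applyPartialConversion(func, target, std::move(patterns))))
//     signalPassFailure();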