//===- SparseTensorConversion.cpp - Sparse tensor primitives conversion ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Convert sparse tensor primitives to calls into a runtime support library.
10 // Note that this is a current implementation choice to keep the conversion
11 // simple. In principle, these primitives could also be converted to actual
12 // elaborate IR code that implements the primitives on the selected sparse
13 // tensor storage schemes.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
18 #include "mlir/Dialect/Linalg/Utils/Utils.h"
19 #include "mlir/Dialect/MemRef/IR/MemRef.h"
20 #include "mlir/Dialect/SCF/SCF.h"
21 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
22 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
23 #include "mlir/Dialect/StandardOps/IR/Ops.h"
24 #include "mlir/Dialect/Tensor/IR/Tensor.h"
25 #include "mlir/Transforms/DialectConversion.h"
26 
27 using namespace mlir;
28 using namespace mlir::sparse_tensor;
29 
30 namespace {
31 
32 //===----------------------------------------------------------------------===//
33 // Helper methods.
34 //===----------------------------------------------------------------------===//
35 
36 /// Returns internal type encoding for primary storage. Keep these
37 /// values consistent with the sparse runtime support library.
38 static unsigned getPrimaryTypeEncoding(Type tp) {
39   if (tp.isF64())
40     return 1;
41   if (tp.isF32())
42     return 2;
43   if (tp.isInteger(64))
44     return 3;
45   if (tp.isInteger(32))
46     return 4;
47   if (tp.isInteger(16))
48     return 5;
49   if (tp.isInteger(8))
50     return 6;
51   return 0;
52 }
53 
54 /// Returns internal type encoding for overhead storage. Keep these
55 /// values consistent with the sparse runtime support library.
56 static unsigned getOverheadTypeEncoding(unsigned width) {
57   switch (width) {
58   default:
59     return 1;
60   case 32:
61     return 2;
62   case 16:
63     return 3;
64   case 8:
65     return 4;
66   }
67 }
68 
69 /// Returns internal dimension level type encoding. Keep these
70 /// values consistent with the sparse runtime support library.
71 static unsigned
72 getDimLevelTypeEncoding(SparseTensorEncodingAttr::DimLevelType dlt) {
73   switch (dlt) {
74   case SparseTensorEncodingAttr::DimLevelType::Dense:
75     return 0;
76   case SparseTensorEncodingAttr::DimLevelType::Compressed:
77     return 1;
78   case SparseTensorEncodingAttr::DimLevelType::Singleton:
79     return 2;
80   }
81   llvm_unreachable("Unknown SparseTensorEncodingAttr::DimLevelType");
82 }
83 
/// Returns a 1-D constant tensor of integers of the given bit width that
/// holds the given values. The static shape is cast into a dynamic shape
/// to ensure that the runtime call signatures remain uniform across
/// sparse tensors of different ranks.
87 static Value getTensor(ConversionPatternRewriter &rewriter, unsigned width,
88                        Location loc, ArrayRef<APInt> values) {
89   Type etp = rewriter.getIntegerType(width);
90   unsigned sz = values.size();
91   RankedTensorType tt1 = RankedTensorType::get({sz}, etp);
92   RankedTensorType tt2 = RankedTensorType::get({ShapedType::kDynamicSize}, etp);
93   auto elts =
94       rewriter.create<ConstantOp>(loc, DenseElementsAttr::get(tt1, values));
95   return rewriter.create<tensor::CastOp>(loc, tt2, elts);
96 }
97 
/// Returns a function reference (the first use also inserts the function
/// declaration into the module).
99 static FlatSymbolRefAttr getFunc(Operation *op, StringRef name, Type result,
100                                  ValueRange operands) {
101   MLIRContext *context = op->getContext();
102   auto module = op->getParentOfType<ModuleOp>();
103   auto func = module.lookupSymbol<FuncOp>(name);
104   if (!func) {
105     OpBuilder moduleBuilder(module.getBodyRegion());
106     moduleBuilder
107         .create<FuncOp>(op->getLoc(), name,
108                         FunctionType::get(context, operands.getTypes(), result))
109         .setPrivate();
110   }
111   return SymbolRefAttr::get(context, name);
112 }
113 
/// Generates a call into the "Swiss army knife" method of the sparse runtime
115 /// support library for materializing sparse tensors into the computation. The
116 /// method returns the call value and assigns the permutation to 'perm'.
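///
/// The 'action' value is forwarded to the support library and, as used in
/// this file, distinguishes three cases (the numeric values are assumed to
/// be kept consistent with the runtime): 0 materializes a sparse tensor
/// from the external source passed in 'ptr' (the new operator), 1 builds a
/// sparse tensor from a coordinate scheme passed in 'ptr', and 2 creates an
/// empty coordinate scheme for subsequent insertions (the convert operator).
///
/// A sketch of the generated call (operand order follows the code below;
/// the exact runtime signature is an assumption):
///
///   %t = call @newSparseTensor(%attrs, %sizes, %perm,
///                              %secPtr, %secInd, %primary, %action, %ptr)
///      : (tensor<?xi8>, tensor<?xi64>, tensor<?xi64>,
///         i64, i64, i64, i32, !llvm.ptr<i8>) -> !llvm.ptr<i8>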
117 static Value genNewCall(ConversionPatternRewriter &rewriter, Operation *op,
118                         SparseTensorEncodingAttr &enc, uint32_t action,
119                         Value &perm, Value ptr = Value()) {
120   Location loc = op->getLoc();
121   ShapedType resType = op->getResult(0).getType().cast<ShapedType>();
122   SmallVector<Value, 8> params;
123   // Sparsity annotations in tensor constant form.
124   SmallVector<APInt, 4> attrs;
125   unsigned sz = enc.getDimLevelType().size();
126   for (unsigned i = 0; i < sz; i++)
127     attrs.push_back(
128         APInt(8, getDimLevelTypeEncoding(enc.getDimLevelType()[i])));
129   params.push_back(getTensor(rewriter, 8, loc, attrs));
130   // Dimension sizes array of the enveloping *dense* tensor. Useful for either
131   // verification of external data, or for construction of internal data.
132   auto shape = resType.getShape();
133   SmallVector<APInt, 4> sizes;
134   for (unsigned i = 0; i < sz; i++) {
135     uint64_t s = shape[i] == ShapedType::kDynamicSize ? 0 : shape[i];
136     sizes.push_back(APInt(64, s));
137   }
138   params.push_back(getTensor(rewriter, 64, loc, sizes));
139   // Dimension order permutation array. This is the "identity" permutation by
140   // default, or otherwise the "reverse" permutation of a given ordering, so
141   // that indices can be mapped quickly to the right position.
142   SmallVector<APInt, 4> rev(sz);
143   if (AffineMap p = enc.getDimOrdering()) {
144     for (unsigned i = 0; i < sz; i++)
145       rev[p.getDimPosition(i)] = APInt(64, i);
146   } else {
147     for (unsigned i = 0; i < sz; i++)
148       rev[i] = APInt(64, i);
149   }
150   perm = getTensor(rewriter, 64, loc, rev);
151   params.push_back(perm);
  // Secondary and primary type encodings.
153   unsigned secPtr = getOverheadTypeEncoding(enc.getPointerBitWidth());
154   unsigned secInd = getOverheadTypeEncoding(enc.getIndexBitWidth());
155   unsigned primary = getPrimaryTypeEncoding(resType.getElementType());
156   assert(primary);
157   params.push_back(
158       rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(secPtr)));
159   params.push_back(
160       rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(secInd)));
161   params.push_back(
162       rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(primary)));
163   // User action and pointer.
164   Type pTp = LLVM::LLVMPointerType::get(IntegerType::get(op->getContext(), 8));
165   if (!ptr)
166     ptr = rewriter.create<LLVM::NullOp>(loc, pTp);
167   params.push_back(
168       rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(action)));
169   params.push_back(ptr);
170   // Generate the call to create new tensor.
171   StringRef name = "newSparseTensor";
172   auto call =
173       rewriter.create<CallOp>(loc, pTp, getFunc(op, name, pTp, params), params);
174   return call.getResult(0);
175 }
176 
177 /// Generates a call that adds one element to a coordinate scheme.
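///
/// A sketch of the generated IR for a rank-2 tensor of f64 values (a
/// hypothetical example; the runtime names follow the switch below):
///
///   %v = tensor.extract %tensor[%i, %j] : tensor<?x?xf64>
///   memref.store %i, %ind[%c0] : memref<?xindex>
///   memref.store %j, %ind[%c1] : memref<?xindex>
///   call @addEltF64(%ptr, %v, %ind, %perm)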
178 static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
179                           Value ptr, Value tensor, Value ind, Value perm,
180                           ValueRange ivs) {
181   Location loc = op->getLoc();
182   StringRef name;
183   Type eltType = tensor.getType().cast<ShapedType>().getElementType();
184   if (eltType.isF64())
185     name = "addEltF64";
186   else if (eltType.isF32())
187     name = "addEltF32";
188   else if (eltType.isInteger(64))
189     name = "addEltI64";
190   else if (eltType.isInteger(32))
191     name = "addEltI32";
192   else if (eltType.isInteger(16))
193     name = "addEltI16";
194   else if (eltType.isInteger(8))
195     name = "addEltI8";
196   else
197     llvm_unreachable("Unknown element type");
198   Value val = rewriter.create<tensor::ExtractOp>(loc, tensor, ivs);
  // TODO: add an 'if val != 0' check here to skip zero entries?
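  // Store the coordinates of the current element in the index buffer.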
200   unsigned i = 0;
201   for (auto iv : ivs) {
202     Value idx = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(i++));
203     rewriter.create<memref::StoreOp>(loc, iv, ind, idx);
204   }
205   SmallVector<Value, 8> params;
206   params.push_back(ptr);
207   params.push_back(val);
208   params.push_back(ind);
209   params.push_back(perm);
210   Type pTp = LLVM::LLVMPointerType::get(IntegerType::get(op->getContext(), 8));
211   rewriter.create<CallOp>(loc, pTp, getFunc(op, name, pTp, params), params);
212 }
213 
214 //===----------------------------------------------------------------------===//
215 // Conversion rules.
216 //===----------------------------------------------------------------------===//
217 
218 /// Sparse conversion rule for returns.
219 class SparseReturnConverter : public OpConversionPattern<ReturnOp> {
220 public:
221   using OpConversionPattern::OpConversionPattern;
222   LogicalResult
223   matchAndRewrite(ReturnOp op, ArrayRef<Value> operands,
224                   ConversionPatternRewriter &rewriter) const override {
225     rewriter.replaceOpWithNewOp<ReturnOp>(op, operands);
226     return success();
227   }
228 };
229 
230 /// Sparse conversion rule for dimension accesses.
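///
/// A sketch of the rewrite, with #CSR a hypothetical sparse encoding (the
/// source tensor has already been converted to an opaque pointer):
///
///   %d = tensor.dim %a, %c1 : tensor<8x8xf64, #CSR>
/// becomes
///   %c1 = constant 1 : index
///   %d = call @sparseDimSize(%a, %c1) : (!llvm.ptr<i8>, index) -> index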
231 class SparseTensorToDimSizeConverter
232     : public OpConversionPattern<tensor::DimOp> {
233 public:
234   using OpConversionPattern::OpConversionPattern;
235   LogicalResult
236   matchAndRewrite(tensor::DimOp op, ArrayRef<Value> operands,
237                   ConversionPatternRewriter &rewriter) const override {
238     Type resType = op.getType();
239     auto enc = getSparseTensorEncoding(op.source().getType());
240     if (!enc)
241       return failure();
242     // Permute the dim index.
243     Optional<int64_t> index = op.getConstantIndex();
244     if (!index.hasValue())
245       return failure();
246     int64_t idx = index.getValue();
247     if (AffineMap p = enc.getDimOrdering())
248       idx = p.getPermutedPosition(idx);
249     // Generate the call.
250     StringRef name = "sparseDimSize";
251     SmallVector<Value, 2> params;
252     params.push_back(operands[0]);
253     params.push_back(
254         rewriter.create<ConstantOp>(op.getLoc(), rewriter.getIndexAttr(idx)));
255     rewriter.replaceOpWithNewOp<CallOp>(
256         op, resType, getFunc(op, name, resType, params), params);
257     return success();
258   }
259 };
260 
261 /// Sparse conversion rule for the new operator.
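///
/// A sketch, with #CSR a hypothetical sparse encoding:
///
///   %t = sparse_tensor.new %source : !llvm.ptr<i8> to tensor<?x?xf64, #CSR>
/// becomes a @newSparseTensor call with action 0 and 'ptr' set to %source.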
262 class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
263   using OpConversionPattern::OpConversionPattern;
264   LogicalResult
265   matchAndRewrite(NewOp op, ArrayRef<Value> operands,
266                   ConversionPatternRewriter &rewriter) const override {
267     Type resType = op.getType();
268     auto enc = getSparseTensorEncoding(resType);
269     if (!enc)
270       return failure();
271     Value perm;
272     rewriter.replaceOp(op, genNewCall(rewriter, op, enc, 0, perm, operands[0]));
273     return success();
274   }
275 };
276 
277 /// Sparse conversion rule for the convert operator.
278 class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
279   using OpConversionPattern::OpConversionPattern;
280   LogicalResult
281   matchAndRewrite(ConvertOp op, ArrayRef<Value> operands,
282                   ConversionPatternRewriter &rewriter) const override {
283     Type resType = op.getType();
284     auto encDst = getSparseTensorEncoding(resType);
285     auto encSrc = getSparseTensorEncoding(op.source().getType());
    // TODO: implement sparse => sparse and sparse => dense conversions.
288     if (!encDst || encSrc)
289       return failure();
290     // This is a dense => sparse conversion, that is handled as follows:
291     //   t = newSparseCOO()
292     //   for i1 in dim1
293     //    ..
294     //     for ik in dimk
295     //       val = a[i1,..,ik]
296     //       if val != 0
297     //         t->add(val, [i1,..,ik], [p1,..,pk])
298     //   s = newSparseTensor(t)
    // Note that the dense tensor traversal code is actually implemented
    // using MLIR IR to avoid having to expose too many low-level
    // memref traversal details to the runtime support library.
302     Location loc = op->getLoc();
303     ShapedType shape = resType.cast<ShapedType>();
304     auto memTp =
305         MemRefType::get({ShapedType::kDynamicSize}, rewriter.getIndexType());
306     Value perm;
307     Value ptr = genNewCall(rewriter, op, encDst, 2, perm);
308     Value tensor = operands[0];
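    // Allocate a rank-sized index buffer used to communicate the
    // coordinates of each element to the runtime support library.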
309     Value arg = rewriter.create<ConstantOp>(
310         loc, rewriter.getIndexAttr(shape.getRank()));
311     Value ind = rewriter.create<memref::AllocaOp>(loc, memTp, ValueRange{arg});
312     SmallVector<Value> lo;
313     SmallVector<Value> hi;
314     SmallVector<Value> st;
315     Value zero = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(0));
316     Value one = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(1));
317     for (unsigned i = 0, rank = shape.getRank(); i < rank; i++) {
318       lo.push_back(zero);
319       hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, tensor, i));
320       st.push_back(one);
321     }
    scf::buildLoopNest(rewriter, loc, lo, hi, st, {},
323                        [&](OpBuilder &builder, Location loc, ValueRange ivs,
324                            ValueRange args) -> scf::ValueVector {
325                          genAddEltCall(rewriter, op, ptr, tensor, ind, perm,
326                                        ivs);
327                          return {};
328                        });
329     rewriter.replaceOp(op, genNewCall(rewriter, op, encDst, 1, perm, ptr));
330     return success();
331   }
332 };
333 
334 /// Sparse conversion rule for pointer accesses.
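///
/// A sketch, with #CSR a hypothetical sparse encoding with 32-bit pointers
/// (the indices and values rules below are analogous):
///
///   %0 = sparse_tensor.pointers %a, %c1
///          : tensor<8x8xf64, #CSR> to memref<?xi32>
/// becomes
///   %0 = call @sparsePointers32(%a, %c1)
///          : (!llvm.ptr<i8>, index) -> memref<?xi32>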
335 class SparseTensorToPointersConverter
336     : public OpConversionPattern<ToPointersOp> {
337 public:
338   using OpConversionPattern::OpConversionPattern;
339   LogicalResult
340   matchAndRewrite(ToPointersOp op, ArrayRef<Value> operands,
341                   ConversionPatternRewriter &rewriter) const override {
342     Type resType = op.getType();
343     Type eltType = resType.cast<ShapedType>().getElementType();
344     StringRef name;
345     if (eltType.isIndex())
346       name = "sparsePointers";
347     else if (eltType.isInteger(64))
348       name = "sparsePointers64";
349     else if (eltType.isInteger(32))
350       name = "sparsePointers32";
351     else if (eltType.isInteger(16))
352       name = "sparsePointers16";
353     else if (eltType.isInteger(8))
354       name = "sparsePointers8";
355     else
356       return failure();
357     rewriter.replaceOpWithNewOp<CallOp>(
358         op, resType, getFunc(op, name, resType, operands), operands);
359     return success();
360   }
361 };
362 
363 /// Sparse conversion rule for index accesses.
364 class SparseTensorToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
365 public:
366   using OpConversionPattern::OpConversionPattern;
367   LogicalResult
368   matchAndRewrite(ToIndicesOp op, ArrayRef<Value> operands,
369                   ConversionPatternRewriter &rewriter) const override {
370     Type resType = op.getType();
371     Type eltType = resType.cast<ShapedType>().getElementType();
372     StringRef name;
373     if (eltType.isIndex())
374       name = "sparseIndices";
375     else if (eltType.isInteger(64))
376       name = "sparseIndices64";
377     else if (eltType.isInteger(32))
378       name = "sparseIndices32";
379     else if (eltType.isInteger(16))
380       name = "sparseIndices16";
381     else if (eltType.isInteger(8))
382       name = "sparseIndices8";
383     else
384       return failure();
385     rewriter.replaceOpWithNewOp<CallOp>(
386         op, resType, getFunc(op, name, resType, operands), operands);
387     return success();
388   }
389 };
390 
391 /// Sparse conversion rule for value accesses.
392 class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
393 public:
394   using OpConversionPattern::OpConversionPattern;
395   LogicalResult
396   matchAndRewrite(ToValuesOp op, ArrayRef<Value> operands,
397                   ConversionPatternRewriter &rewriter) const override {
398     Type resType = op.getType();
399     Type eltType = resType.cast<ShapedType>().getElementType();
400     StringRef name;
401     if (eltType.isF64())
402       name = "sparseValuesF64";
403     else if (eltType.isF32())
404       name = "sparseValuesF32";
405     else if (eltType.isInteger(64))
406       name = "sparseValuesI64";
407     else if (eltType.isInteger(32))
408       name = "sparseValuesI32";
409     else if (eltType.isInteger(16))
410       name = "sparseValuesI16";
411     else if (eltType.isInteger(8))
412       name = "sparseValuesI8";
413     else
414       return failure();
415     rewriter.replaceOpWithNewOp<CallOp>(
416         op, resType, getFunc(op, name, resType, operands), operands);
417     return success();
418   }
419 };
420 
421 /// Sparse conversion rule for tensor reconstruction.
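///
/// A sketch: when operands %p, %i, %v all originate from support library
/// calls that query the same opaque pointer %t, an op such as
///
///   %0 = sparse_tensor.tensor %p, %i, %v : ... to tensor<8x8xf64, #CSR>
///
/// folds away, and all uses of %0 are replaced with %t.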
422 class SparseTensorToTensorConverter : public OpConversionPattern<ToTensorOp> {
423 public:
424   using OpConversionPattern::OpConversionPattern;
  // Simply fold the operator into the pointer to the sparse storage scheme.
  LogicalResult
  matchAndRewrite(ToTensorOp op, ArrayRef<Value> operands,
428                   ConversionPatternRewriter &rewriter) const override {
    // Check that all arguments of the tensor reconstruction operator are
    // calls into the support library that query exactly the same opaque
    // pointer.
431     Value ptr;
    for (Value opd : operands) {
      if (auto call = opd.getDefiningOp<CallOp>()) {
434         Value arg = call.getOperand(0);
435         if (!arg.getType().isa<LLVM::LLVMPointerType>())
436           return failure();
437         if (!ptr)
438           ptr = arg;
439         else if (arg != ptr)
440           return failure();
441       }
442     }
443     // If a single opaque pointer is found, perform the folding.
444     if (!ptr)
445       return failure();
446     rewriter.replaceOp(op, ptr);
447     return success();
448   }
449 };
450 
451 } // namespace
452 
453 //===----------------------------------------------------------------------===//
454 // Public method for populating conversion rules.
455 //===----------------------------------------------------------------------===//
456 
/// Populates the given patterns list with conversion rules required to
/// convert sparse tensor primitives to calls into the runtime support
/// library.
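///
/// A minimal usage sketch (pass boilerplate omitted; the type converter is
/// assumed to map sparse tensor types to opaque pointers):
///
///   RewritePatternSet patterns(ctx);
///   populateSparseTensorConversionPatterns(typeConverter, patterns);
///   (void)applyPartialConversion(func, target, std::move(patterns));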
459 void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
460                                                   RewritePatternSet &patterns) {
461   patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter,
462                SparseTensorNewConverter, SparseTensorConvertConverter,
463                SparseTensorToPointersConverter, SparseTensorToIndicesConverter,
464                SparseTensorToValuesConverter, SparseTensorToTensorConverter>(
465       typeConverter, patterns.getContext());
466 }
467