1 //===- SparseTensorLowering.cpp - Sparse tensor primitives conversion -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Convert sparse tensor primitives to calls into a runtime support library.
10 // Note that this is a current implementation choice to keep the conversion
11 // simple. In principle, these primitives could also be converted to actual
12 // elaborate IR code that implements the primitives on the selected sparse
13 // tensor storage schemes.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
18 #include "mlir/Dialect/Linalg/Utils/Utils.h"
19 #include "mlir/Dialect/MemRef/IR/MemRef.h"
20 #include "mlir/Dialect/SCF/SCF.h"
21 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
22 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
23 #include "mlir/Dialect/StandardOps/IR/Ops.h"
24 #include "mlir/Dialect/Tensor/IR/Tensor.h"
25 #include "mlir/Transforms/DialectConversion.h"
26 
27 using namespace mlir;
28 using namespace mlir::sparse_tensor;
29 
30 namespace {
31 
32 //===----------------------------------------------------------------------===//
33 // Helper methods.
34 //===----------------------------------------------------------------------===//
35 
36 /// Returns internal type encoding for primary storage. Keep these
37 /// values consistent with the sparse runtime support library.
38 static unsigned getPrimaryTypeEncoding(Type tp) {
39   if (tp.isF64())
40     return 1;
41   if (tp.isF32())
42     return 2;
43   if (tp.isInteger(64))
44     return 3;
45   if (tp.isInteger(32))
46     return 4;
47   if (tp.isInteger(16))
48     return 5;
49   if (tp.isInteger(8))
50     return 6;
51   return 0;
52 }
53 
/// Returns internal type encoding for overhead storage. Keep these
/// values consistent with the sparse runtime support library.
/// Widths 32/16/8 get dedicated codes; every other width (including
/// the 64-bit default) maps to code 1.
static unsigned getOverheadTypeEncoding(unsigned width) {
  if (width == 32)
    return 2;
  if (width == 16)
    return 3;
  if (width == 8)
    return 4;
  return 1;
}
68 
69 /// Returns internal dimension level type encoding. Keep these
70 /// values consistent with the sparse runtime support library.
71 static unsigned
72 getDimLevelTypeEncoding(SparseTensorEncodingAttr::DimLevelType dlt) {
73   switch (dlt) {
74   case SparseTensorEncodingAttr::DimLevelType::Dense:
75     return 0;
76   case SparseTensorEncodingAttr::DimLevelType::Compressed:
77     return 1;
78   case SparseTensorEncodingAttr::DimLevelType::Singleton:
79     return 2;
80   }
81   llvm_unreachable("Unknown SparseTensorEncodingAttr::DimLevelType");
82 }
83 
84 /// Returns integers of given width and values as a constant tensor.
85 /// We cast the static shape into a dynamic shape to ensure that the
86 /// method signature remains uniform across different tensor dimensions.
87 static Value getTensor(ConversionPatternRewriter &rewriter, unsigned width,
88                        Location loc, ArrayRef<APInt> values) {
89   Type etp = rewriter.getIntegerType(width);
90   unsigned sz = values.size();
91   RankedTensorType tt1 = RankedTensorType::get({sz}, etp);
92   RankedTensorType tt2 = RankedTensorType::get({ShapedType::kDynamicSize}, etp);
93   auto elts =
94       rewriter.create<ConstantOp>(loc, DenseElementsAttr::get(tt1, values));
95   return rewriter.create<tensor::CastOp>(loc, tt2, elts);
96 }
97 
98 /// Returns a function reference (first hit also inserts into module). Sets
99 /// the "_emit_c_interface" on the function declaration when requested,
100 /// so that LLVM lowering generates a wrapper function that takes care
101 /// of ABI complications with passing in and returning MemRefs to C functions.
102 static FlatSymbolRefAttr getFunc(Operation *op, StringRef name, Type resultType,
103                                  ValueRange operands,
104                                  bool emitCInterface = false) {
105   MLIRContext *context = op->getContext();
106   auto module = op->getParentOfType<ModuleOp>();
107   auto result = SymbolRefAttr::get(context, name);
108   auto func = module.lookupSymbol<FuncOp>(result.getAttr());
109   if (!func) {
110     OpBuilder moduleBuilder(module.getBodyRegion());
111     func = moduleBuilder.create<FuncOp>(
112         op->getLoc(), name,
113         FunctionType::get(context, operands.getTypes(), resultType));
114     func.setPrivate();
115     if (emitCInterface)
116       func->setAttr("llvm.emit_c_interface", UnitAttr::get(context));
117   }
118   return result;
119 }
120 
/// Generates a call into the "swiss army knife" method of the sparse runtime
/// support library for materializing sparse tensors into the computation. The
/// method returns the call value and assigns the permutation to 'perm'.
///
/// The runtime call receives, in order: the per-dimension level types, the
/// dimension sizes, the (reverse) dimension ordering, the pointer/index/
/// primary type encodings, the 'action' code, and an opaque pointer (null
/// when the action needs no input object).
/// NOTE(review): the semantics of each 'action' value are defined by the
/// runtime support library (callers here pass 0, 1, 2, and 3) — confirm
/// against the runtime implementation when changing call sites.
static Value genNewCall(ConversionPatternRewriter &rewriter, Operation *op,
                        SparseTensorEncodingAttr &enc, uint32_t action,
                        Value &perm, Value ptr = Value()) {
  Location loc = op->getLoc();
  // The shaped type of the op's first result drives sizes and element type;
  // assumes its rank matches the number of level types in the encoding.
  ShapedType resType = op->getResult(0).getType().cast<ShapedType>();
  SmallVector<Value, 8> params;
  // Sparsity annotations in tensor constant form.
  SmallVector<APInt, 4> attrs;
  unsigned sz = enc.getDimLevelType().size();
  for (unsigned i = 0; i < sz; i++)
    attrs.push_back(
        APInt(8, getDimLevelTypeEncoding(enc.getDimLevelType()[i])));
  params.push_back(getTensor(rewriter, 8, loc, attrs));
  // Dimension sizes array of the enveloping *dense* tensor. Useful for either
  // verification of external data, or for construction of internal data.
  // Dynamic extents are encoded as size 0.
  auto shape = resType.getShape();
  SmallVector<APInt, 4> sizes;
  for (unsigned i = 0; i < sz; i++) {
    uint64_t s = shape[i] == ShapedType::kDynamicSize ? 0 : shape[i];
    sizes.push_back(APInt(64, s));
  }
  params.push_back(getTensor(rewriter, 64, loc, sizes));
  // Dimension order permutation array. This is the "identity" permutation by
  // default, or otherwise the "reverse" permutation of a given ordering, so
  // that indices can be mapped quickly to the right position.
  SmallVector<APInt, 4> rev(sz);
  if (AffineMap p = enc.getDimOrdering()) {
    for (unsigned i = 0; i < sz; i++)
      rev[p.getDimPosition(i)] = APInt(64, i);
  } else {
    for (unsigned i = 0; i < sz; i++)
      rev[i] = APInt(64, i);
  }
  // Expose the permutation to the caller so that follow-up calls (e.g.
  // element insertion) can reuse it.
  perm = getTensor(rewriter, 64, loc, rev);
  params.push_back(perm);
  // Secondary and primary types encoding.
  unsigned secPtr = getOverheadTypeEncoding(enc.getPointerBitWidth());
  unsigned secInd = getOverheadTypeEncoding(enc.getIndexBitWidth());
  unsigned primary = getPrimaryTypeEncoding(resType.getElementType());
  // Element type must be supported by the runtime library (nonzero code).
  assert(primary);
  params.push_back(
      rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(secPtr)));
  params.push_back(
      rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(secInd)));
  params.push_back(
      rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(primary)));
  // User action and pointer. A null pointer is passed when the caller has
  // no input object for the requested action.
  Type pTp = LLVM::LLVMPointerType::get(IntegerType::get(op->getContext(), 8));
  if (!ptr)
    ptr = rewriter.create<LLVM::NullOp>(loc, pTp);
  params.push_back(
      rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(action)));
  params.push_back(ptr);
  // Generate the call to create new tensor.
  StringRef name = "newSparseTensor";
  auto call = rewriter.create<CallOp>(
      loc, pTp, getFunc(op, name, pTp, params, /*emitCInterface=*/true),
      params);
  return call.getResult(0);
}
184 
/// Generates a call that adds one element to a coordinate scheme.
/// The element value is read from 'tensor' at the loop indices 'ivs',
/// the indices themselves are staged through the scratch memref 'ind',
/// and 'perm' carries the dimension ordering computed by genNewCall.
static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
                          Value ptr, Value tensor, Value ind, Value perm,
                          ValueRange ivs) {
  Location loc = op->getLoc();
  // Select the runtime entry point that matches the element type;
  // any other element type is a caller error.
  StringRef name;
  Type eltType = tensor.getType().cast<ShapedType>().getElementType();
  if (eltType.isF64())
    name = "addEltF64";
  else if (eltType.isF32())
    name = "addEltF32";
  else if (eltType.isInteger(64))
    name = "addEltI64";
  else if (eltType.isInteger(32))
    name = "addEltI32";
  else if (eltType.isInteger(16))
    name = "addEltI16";
  else if (eltType.isInteger(8))
    name = "addEltI8";
  else
    llvm_unreachable("Unknown element type");
  Value val = rewriter.create<tensor::ExtractOp>(loc, tensor, ivs);
  // TODO: add if here?
  // Stage the current loop indices into the scratch index memref, one
  // store per dimension.
  unsigned i = 0;
  for (auto iv : ivs) {
    Value idx = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(i++));
    rewriter.create<memref::StoreOp>(loc, iv, ind, idx);
  }
  // Call addElt(ptr, val, ind, perm) in the runtime support library.
  SmallVector<Value, 8> params;
  params.push_back(ptr);
  params.push_back(val);
  params.push_back(ind);
  params.push_back(perm);
  Type pTp = LLVM::LLVMPointerType::get(IntegerType::get(op->getContext(), 8));
  rewriter.create<CallOp>(
      loc, pTp, getFunc(op, name, pTp, params, /*emitCInterface=*/true),
      params);
}
223 
224 //===----------------------------------------------------------------------===//
225 // Conversion rules.
226 //===----------------------------------------------------------------------===//
227 
/// Sparse conversion rule for returns: rebuilds the return with the
/// type-converted operands (opaque pointers instead of sparse tensors).
class SparseReturnConverter : public OpConversionPattern<ReturnOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReturnOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<ReturnOp>(op, operands);
    return success();
  }
};
239 
/// Sparse conversion rule for dimension accesses. Rewrites a
/// tensor.dim on a sparse tensor into a "sparseDimSize" runtime query,
/// after permuting the dimension index through the encoding's ordering.
class SparseTensorToDimSizeConverter
    : public OpConversionPattern<tensor::DimOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::DimOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    // Only sparse-annotated sources are handled by this pattern.
    auto enc = getSparseTensorEncoding(op.source().getType());
    if (!enc)
      return failure();
    // Permute the dim index. Only constant indices are supported, since
    // the permutation must be resolved at rewrite time.
    Optional<int64_t> index = op.getConstantIndex();
    if (!index.hasValue())
      return failure();
    int64_t idx = index.getValue();
    if (AffineMap p = enc.getDimOrdering())
      idx = p.getPermutedPosition(idx);
    // Generate the call: sparseDimSize(ptr, idx).
    StringRef name = "sparseDimSize";
    SmallVector<Value, 2> params;
    params.push_back(operands[0]);
    params.push_back(
        rewriter.create<ConstantOp>(op.getLoc(), rewriter.getIndexAttr(idx)));
    rewriter.replaceOpWithNewOp<CallOp>(
        op, resType, getFunc(op, name, resType, params), params);
    return success();
  }
};
270 
/// Sparse conversion rule for the new operator. Lowers the op into a
/// "newSparseTensor" runtime call with action code 0, forwarding the
/// op's operand as the opaque input pointer.
class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NewOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    // Only sparse-annotated result types are handled by this pattern.
    auto enc = getSparseTensorEncoding(resType);
    if (!enc)
      return failure();
    // The permutation produced by genNewCall is not needed here.
    Value perm;
    rewriter.replaceOp(op, genNewCall(rewriter, op, enc, 0, perm, operands[0]));
    return success();
  }
};
286 
/// Sparse conversion rule for the convert operator. Handles the
/// sparse=>sparse and dense=>sparse cases; sparse=>dense is not yet
/// implemented and yields failure.
class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ConvertOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    auto encDst = getSparseTensorEncoding(resType);
    auto encSrc = getSparseTensorEncoding(op.source().getType());
    if (encDst && encSrc) {
      // This is a sparse => sparse conversion, which is handled as follows:
      //   t = src->asCOO();         ; src to COO in dst order
      //   dst = newSparseTensor(t)
      // Using the coordinate scheme as an intermediate does not always
      // yield the fastest conversion but avoids the need for a full
      // O(N^2) conversion matrix.
      // Action 3 extracts the COO view, action 1 builds the destination
      // tensor from it (see genNewCall).
      Value perm;
      Value coo = genNewCall(rewriter, op, encDst, 3, perm, operands[0]);
      rewriter.replaceOp(op, genNewCall(rewriter, op, encDst, 1, perm, coo));
      return success();
    }
    // At this point at most one side is sparse. Reject everything that
    // is not a dense => sparse conversion.
    if (!encDst || encSrc) {
      // TODO: sparse => dense
      return failure();
    }
    // This is a dense => sparse conversion, which is handled as follows:
    //   t = newSparseCOO()
    //   for i1 in dim1
    //    ..
    //     for ik in dimk
    //       val = a[i1,..,ik]
    //       if val != 0
    //         t->add(val, [i1,..,ik], [p1,..,pk])
    //   s = newSparseTensor(t)
    // Note that the dense tensor traversal code is actually implemented
    // using MLIR IR to avoid having to expose too much low-level
    // memref traversal details to the runtime support library.
    Location loc = op->getLoc();
    ShapedType shape = resType.cast<ShapedType>();
    auto memTp =
        MemRefType::get({ShapedType::kDynamicSize}, rewriter.getIndexType());
    // Action 2 creates the empty COO scheme to insert into.
    Value perm;
    Value ptr = genNewCall(rewriter, op, encDst, 2, perm);
    Value tensor = operands[0];
    // Scratch buffer holding the indices of one element during insertion.
    Value arg = rewriter.create<ConstantOp>(
        loc, rewriter.getIndexAttr(shape.getRank()));
    Value ind = rewriter.create<memref::AllocaOp>(loc, memTp, ValueRange{arg});
    // Loop bounds: [0, dim_i) with unit stride for every dimension.
    SmallVector<Value> lo;
    SmallVector<Value> hi;
    SmallVector<Value> st;
    Value zero = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(0));
    Value one = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(1));
    for (unsigned i = 0, rank = shape.getRank(); i < rank; i++) {
      lo.push_back(zero);
      hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, tensor, i));
      st.push_back(one);
    }
    // Emit the loop nest that inserts every element of the dense tensor.
    scf::buildLoopNest(rewriter, op.getLoc(), lo, hi, st, {},
                       [&](OpBuilder &builder, Location loc, ValueRange ivs,
                           ValueRange args) -> scf::ValueVector {
                         genAddEltCall(rewriter, op, ptr, tensor, ind, perm,
                                       ivs);
                         return {};
                       });
    // Action 1 finalizes the COO scheme into the destination tensor.
    rewriter.replaceOp(op, genNewCall(rewriter, op, encDst, 1, perm, ptr));
    return success();
  }
};
355 
356 /// Sparse conversion rule for pointer accesses.
357 class SparseTensorToPointersConverter
358     : public OpConversionPattern<ToPointersOp> {
359 public:
360   using OpConversionPattern::OpConversionPattern;
361   LogicalResult
362   matchAndRewrite(ToPointersOp op, ArrayRef<Value> operands,
363                   ConversionPatternRewriter &rewriter) const override {
364     Type resType = op.getType();
365     Type eltType = resType.cast<ShapedType>().getElementType();
366     StringRef name;
367     if (eltType.isIndex())
368       name = "sparsePointers";
369     else if (eltType.isInteger(64))
370       name = "sparsePointers64";
371     else if (eltType.isInteger(32))
372       name = "sparsePointers32";
373     else if (eltType.isInteger(16))
374       name = "sparsePointers16";
375     else if (eltType.isInteger(8))
376       name = "sparsePointers8";
377     else
378       return failure();
379     rewriter.replaceOpWithNewOp<CallOp>(
380         op, resType,
381         getFunc(op, name, resType, operands, /*emitCInterface=*/true),
382         operands);
383     return success();
384   }
385 };
386 
387 /// Sparse conversion rule for index accesses.
388 class SparseTensorToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
389 public:
390   using OpConversionPattern::OpConversionPattern;
391   LogicalResult
392   matchAndRewrite(ToIndicesOp op, ArrayRef<Value> operands,
393                   ConversionPatternRewriter &rewriter) const override {
394     Type resType = op.getType();
395     Type eltType = resType.cast<ShapedType>().getElementType();
396     StringRef name;
397     if (eltType.isIndex())
398       name = "sparseIndices";
399     else if (eltType.isInteger(64))
400       name = "sparseIndices64";
401     else if (eltType.isInteger(32))
402       name = "sparseIndices32";
403     else if (eltType.isInteger(16))
404       name = "sparseIndices16";
405     else if (eltType.isInteger(8))
406       name = "sparseIndices8";
407     else
408       return failure();
409     rewriter.replaceOpWithNewOp<CallOp>(
410         op, resType,
411         getFunc(op, name, resType, operands, /*emitCInterface=*/true),
412         operands);
413     return success();
414   }
415 };
416 
417 /// Sparse conversion rule for value accesses.
418 class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
419 public:
420   using OpConversionPattern::OpConversionPattern;
421   LogicalResult
422   matchAndRewrite(ToValuesOp op, ArrayRef<Value> operands,
423                   ConversionPatternRewriter &rewriter) const override {
424     Type resType = op.getType();
425     Type eltType = resType.cast<ShapedType>().getElementType();
426     StringRef name;
427     if (eltType.isF64())
428       name = "sparseValuesF64";
429     else if (eltType.isF32())
430       name = "sparseValuesF32";
431     else if (eltType.isInteger(64))
432       name = "sparseValuesI64";
433     else if (eltType.isInteger(32))
434       name = "sparseValuesI32";
435     else if (eltType.isInteger(16))
436       name = "sparseValuesI16";
437     else if (eltType.isInteger(8))
438       name = "sparseValuesI8";
439     else
440       return failure();
441     rewriter.replaceOpWithNewOp<CallOp>(
442         op, resType,
443         getFunc(op, name, resType, operands, /*emitCInterface=*/true),
444         operands);
445     return success();
446   }
447 };
448 
/// Sparse conversion rule for tensor reconstruction.
class SparseTensorToTensorConverter : public OpConversionPattern<ToTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  // Simply fold the operator into the pointer to the sparse storage scheme.
  matchAndRewrite(ToTensorOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // Check that all arguments of the tensor reconstruction operators are calls
    // into the support library that query exactly the same opaque pointer.
    Value ptr;
    for (Value op : operands) {
      if (auto call = op.getDefiningOp<CallOp>()) {
        // The opaque pointer is the first argument of such library calls.
        Value arg = call.getOperand(0);
        if (!arg.getType().isa<LLVM::LLVMPointerType>())
          return failure();
        if (!ptr)
          ptr = arg;
        else if (arg != ptr)
          return failure();
      }
    }
    // If a single opaque pointer is found, perform the folding.
    // Note: operands without a defining CallOp are skipped above, so
    // failure here means no library-call operand was seen at all.
    if (!ptr)
      return failure();
    rewriter.replaceOp(op, ptr);
    return success();
  }
};
478 
479 } // namespace
480 
//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with conversion rules required for
/// the sparsification of linear algebra operations. Registers one pattern
/// per sparse tensor primitive (new/convert/to_pointers/to_indices/
/// to_values/tensor) plus the return-op rewriter.
void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
                                                  RewritePatternSet &patterns) {
  patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter,
               SparseTensorNewConverter, SparseTensorConvertConverter,
               SparseTensorToPointersConverter, SparseTensorToIndicesConverter,
               SparseTensorToValuesConverter, SparseTensorToTensorConverter>(
      typeConverter, patterns.getContext());
}
495