//===- SparseTensorConversion.cpp - Sparse tensor primitives conversion ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Convert sparse tensor primitives to calls into a runtime support library.
// Note that calling into a runtime library is a choice made by the current
// implementation to keep the conversion simple. In principle, these
// primitives could also be lowered to elaborate IR that implements them
// directly on the selected sparse tensor storage schemes.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

/// Returns internal type encoding for primary storage. Keep these
/// values consistent with the sparse runtime support library.
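/// A zero return value indicates an unsupported element type; callers are
/// expected to check (or assert) for a nonzero encoding.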
static unsigned getPrimaryTypeEncoding(Type tp) {
  if (tp.isF64())
    return 1;
  if (tp.isF32())
    return 2;
  if (tp.isInteger(64))
    return 3;
  if (tp.isInteger(32))
    return 4;
  if (tp.isInteger(16))
    return 5;
  if (tp.isInteger(8))
    return 6;
  return 0;
}

/// Returns internal type encoding for overhead storage. Keep these
/// values consistent with the sparse runtime support library.
static unsigned getOverheadTypeEncoding(unsigned width) {
  switch (width) {
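  // Note that the default case handles width 64, as well as width 0, which
  // the encoding uses to denote the unspecified (native) bitwidth.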
  default:
    return 1;
  case 32:
    return 2;
  case 16:
    return 3;
  case 8:
    return 4;
  }
}

/// Returns internal dimension level type encoding. Keep these
/// values consistent with the sparse runtime support library.
static unsigned
getDimLevelTypeEncoding(SparseTensorEncodingAttr::DimLevelType dlt) {
  switch (dlt) {
  case SparseTensorEncodingAttr::DimLevelType::Dense:
    return 0;
  case SparseTensorEncodingAttr::DimLevelType::Compressed:
    return 1;
  case SparseTensorEncodingAttr::DimLevelType::Singleton:
    return 2;
  }
  llvm_unreachable("Unknown SparseTensorEncodingAttr::DimLevelType");
}

/// Returns a 1-D constant tensor of integers of the given width holding the
/// given values. We cast the static shape into a dynamic shape to ensure
/// that the method signature remains uniform across tensors of different
/// sizes.
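/// For example, width 8 with values {0, 1, 1} yields roughly:
///   %0 = constant dense<[0, 1, 1]> : tensor<3xi8>
///   %1 = tensor.cast %0 : tensor<3xi8> to tensor<?xi8>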
static Value getTensor(ConversionPatternRewriter &rewriter, unsigned width,
                       Location loc, ArrayRef<APInt> values) {
  Type etp = rewriter.getIntegerType(width);
  unsigned sz = values.size();
  RankedTensorType tt1 = RankedTensorType::get({sz}, etp);
  RankedTensorType tt2 = RankedTensorType::get({ShapedType::kDynamicSize}, etp);
  auto elts =
      rewriter.create<ConstantOp>(loc, DenseElementsAttr::get(tt1, values));
  return rewriter.create<tensor::CastOp>(loc, tt2, elts);
}

/// Returns a function reference (first hit also inserts into module). Sets
/// the "llvm.emit_c_interface" attribute on the function declaration when
/// requested, so that LLVM lowering generates a wrapper function that takes
/// care of ABI complications with passing in and returning MemRefs to C
/// functions.
static FlatSymbolRefAttr getFunc(Operation *op, StringRef name,
                                 TypeRange resultType, ValueRange operands,
                                 bool emitCInterface = false) {
  MLIRContext *context = op->getContext();
  auto module = op->getParentOfType<ModuleOp>();
  auto result = SymbolRefAttr::get(context, name);
  auto func = module.lookupSymbol<FuncOp>(result.getAttr());
  if (!func) {
    OpBuilder moduleBuilder(module.getBodyRegion());
    func = moduleBuilder.create<FuncOp>(
        op->getLoc(), name,
        FunctionType::get(context, operands.getTypes(), resultType));
    func.setPrivate();
    if (emitCInterface)
      func->setAttr("llvm.emit_c_interface", UnitAttr::get(context));
  }
  return result;
}

/// Generates a call into the "swiss army knife" method of the sparse runtime
/// support library for materializing sparse tensors into the computation. The
/// method returns the call value and assigns the permutation to 'perm'.
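/// The action values are kept consistent with the runtime support library;
/// their semantics, as inferred from the call sites below, are:
///   0 : ptr contains the external source from which to read new storage
///   1 : ptr contains a coordinate scheme to assign to new storage
///   2 : returns an empty coordinate scheme, to be filled and passed back
///       with action 1
///   3 : returns the coordinate scheme extracted from the storage in ptr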
static Value genNewCall(ConversionPatternRewriter &rewriter, Operation *op,
                        SparseTensorEncodingAttr &enc, uint32_t action,
                        Value &perm, Value ptr = Value()) {
  Location loc = op->getLoc();
  ShapedType resType = op->getResult(0).getType().cast<ShapedType>();
  SmallVector<Value, 8> params;
  // Sparsity annotations in tensor constant form.
  SmallVector<APInt, 4> attrs;
  unsigned sz = enc.getDimLevelType().size();
  for (unsigned i = 0; i < sz; i++)
    attrs.push_back(
        APInt(8, getDimLevelTypeEncoding(enc.getDimLevelType()[i])));
  params.push_back(getTensor(rewriter, 8, loc, attrs));
  // Dimension sizes array of the enveloping *dense* tensor. Useful for either
  // verification of external data, or for construction of internal data.
  auto shape = resType.getShape();
  SmallVector<APInt, 4> sizes;
  for (unsigned i = 0; i < sz; i++) {
    uint64_t s = shape[i] == ShapedType::kDynamicSize ? 0 : shape[i];
    sizes.push_back(APInt(64, s));
  }
  params.push_back(getTensor(rewriter, 64, loc, sizes));
  // Dimension order permutation array. This is the "identity" permutation by
  // default, or otherwise the "reverse" permutation of a given ordering, so
  // that indices can be mapped quickly to the right position.
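  // For example, for the ordering (i,j,k) -> (k,i,j), the reverse permutation
  // is [1,2,0]: dimension i resides at position 1 of the stored order, j at
  // position 2, and k at position 0.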
  SmallVector<APInt, 4> rev(sz);
  if (AffineMap p = enc.getDimOrdering()) {
    for (unsigned i = 0; i < sz; i++)
      rev[p.getDimPosition(i)] = APInt(64, i);
  } else {
    for (unsigned i = 0; i < sz; i++)
      rev[i] = APInt(64, i);
  }
  perm = getTensor(rewriter, 64, loc, rev);
  params.push_back(perm);
  // Secondary and primary types encoding.
  unsigned secPtr = getOverheadTypeEncoding(enc.getPointerBitWidth());
  unsigned secInd = getOverheadTypeEncoding(enc.getIndexBitWidth());
  unsigned primary = getPrimaryTypeEncoding(resType.getElementType());
  assert(primary);
  params.push_back(
      rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(secPtr)));
  params.push_back(
      rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(secInd)));
  params.push_back(
      rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(primary)));
  // User action and pointer.
  Type pTp = LLVM::LLVMPointerType::get(IntegerType::get(op->getContext(), 8));
  if (!ptr)
    ptr = rewriter.create<LLVM::NullOp>(loc, pTp);
  params.push_back(
      rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(action)));
  params.push_back(ptr);
  // Generate the call to create new tensor.
  StringRef name = "newSparseTensor";
  auto call = rewriter.create<CallOp>(
      loc, pTp, getFunc(op, name, pTp, params, /*emitCInterface=*/true),
      params);
  return call.getResult(0);
}

/// Generates a constant zero of the given type.
static Value getZero(ConversionPatternRewriter &rewriter, Location loc,
                     Type t) {
  return rewriter.create<ConstantOp>(loc, rewriter.getZeroAttr(t));
}

/// Generates the comparison `v != 0` where `v` is of numeric type `t`.
/// For floating types, we use the "unordered" comparator (i.e., returns
/// true if `v` is NaN).
static Value genIsNonzero(ConversionPatternRewriter &rewriter, Location loc,
                          Value v) {
  Type t = v.getType();
  Value zero = getZero(rewriter, loc, t);
  if (t.isa<FloatType>())
    return rewriter.create<CmpFOp>(loc, CmpFPredicate::UNE, v, zero);
  if (t.isIntOrIndex())
    return rewriter.create<CmpIOp>(loc, CmpIPredicate::ne, v, zero);
  llvm_unreachable("Unknown element type");
}

/// Generates code that reads the value from tensor[ivs] and, if the value is
/// nonzero, stores the indices ivs to the memory in ind. The generated code
/// looks like the following, and the insertion point after this routine is
/// inside the if-then branch, right after the assignment to ind. This ensures
/// that the addEltX call generated afterwards is placed inside the if-then
/// branch as well.
///    if (tensor[ivs] != 0) {
///      ind = ivs
static Value genIndexAndValueForDense(ConversionPatternRewriter &rewriter,
                                      Operation *op, Value tensor, Value ind,
                                      ValueRange ivs) {
  Location loc = op->getLoc();
  Value val = rewriter.create<tensor::ExtractOp>(loc, tensor, ivs);
  Value cond = genIsNonzero(rewriter, loc, val);
  scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ false);
  rewriter.setInsertionPointToStart(&ifOp.thenRegion().front());
  unsigned i = 0;
  for (auto iv : ivs) {
    Value idx = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(i++));
    rewriter.create<memref::StoreOp>(loc, iv, ind, idx);
  }
  return val;
}

/// Generates a call that adds one element to a coordinate scheme.
/// In particular, this generates code like the following:
///   val = a[i1,..,ik];
///   if val != 0
///     t->add(val, [i1,..,ik], [p1,..,pk]);
static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
                          Type eltType, Value ptr, Value val, Value ind,
                          Value perm) {
  Location loc = op->getLoc();
  StringRef name;
  if (eltType.isF64())
    name = "addEltF64";
  else if (eltType.isF32())
    name = "addEltF32";
  else if (eltType.isInteger(64))
    name = "addEltI64";
  else if (eltType.isInteger(32))
    name = "addEltI32";
  else if (eltType.isInteger(16))
    name = "addEltI16";
  else if (eltType.isInteger(8))
    name = "addEltI8";
  else
    llvm_unreachable("Unknown element type");
  SmallVector<Value, 8> params;
  params.push_back(ptr);
  params.push_back(val);
  params.push_back(ind);
  params.push_back(perm);
  Type pTp = LLVM::LLVMPointerType::get(IntegerType::get(op->getContext(), 8));
  rewriter.create<CallOp>(
      loc, pTp, getFunc(op, name, pTp, params, /*emitCInterface=*/true),
      params);
}

/// If the tensor is a sparse constant, generates and returns a pair of
/// constants holding its indices and its values, respectively.
static Optional<std::pair<Value, Value>>
genSplitSparseConstant(ConversionPatternRewriter &rewriter, ConvertOp op,
                       Value tensor) {
  if (auto constOp = tensor.getDefiningOp<ConstantOp>()) {
    if (auto attr = constOp.value().dyn_cast<SparseElementsAttr>()) {
      Location loc = op->getLoc();
      DenseElementsAttr indicesAttr = attr.getIndices();
      Value indices = rewriter.create<ConstantOp>(loc, indicesAttr);
      DenseElementsAttr valuesAttr = attr.getValues();
      Value values = rewriter.create<ConstantOp>(loc, valuesAttr);
      return std::make_pair(indices, values);
    }
  }
  return {};
}

/// Generates code that copies the indices stored at indices[ivs] into ind,
/// and returns the value at values[ivs].
static Value genIndexAndValueForSparse(ConversionPatternRewriter &rewriter,
                                       Operation *op, Value indices,
                                       Value values, Value ind, ValueRange ivs,
                                       unsigned rank) {
  Location loc = op->getLoc();
  for (unsigned i = 0; i < rank; i++) {
    Value idx = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(i));
    Value val = rewriter.create<tensor::ExtractOp>(loc, indices,
                                                   ValueRange{ivs[0], idx});
    val = rewriter.create<IndexCastOp>(loc, val, rewriter.getIndexType());
    rewriter.create<memref::StoreOp>(loc, val, ind, idx);
  }
  return rewriter.create<tensor::ExtractOp>(loc, values, ivs[0]);
}

/// Generates code to stack-allocate a `memref<?xindex>` where the `?`
/// is the given `rank`.  This array is intended to serve as a reusable
/// buffer for storing the indices of a single tensor element, to avoid
/// allocation in the body of loops.
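/// For example, a rank of 2 yields roughly:
///   %c2 = constant 2 : index
///   %0 = memref.alloca(%c2) : memref<?xindex>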
static Value allocaIndices(ConversionPatternRewriter &rewriter, Location loc,
                           int64_t rank) {
  auto indexTp = rewriter.getIndexType();
  auto memTp = MemRefType::get({ShapedType::kDynamicSize}, indexTp);
  Value arg = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(rank));
  return rewriter.create<memref::AllocaOp>(loc, memTp, ValueRange{arg});
}

//===----------------------------------------------------------------------===//
// Conversion rules.
//===----------------------------------------------------------------------===//

/// Sparse conversion rule for returns.
class SparseReturnConverter : public OpConversionPattern<ReturnOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<ReturnOp>(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for dimension accesses.
class SparseTensorToDimSizeConverter
    : public OpConversionPattern<tensor::DimOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    auto enc = getSparseTensorEncoding(op.source().getType());
    if (!enc)
      return failure();
    // Permute the dim index.
    Optional<int64_t> index = op.getConstantIndex();
    if (!index.hasValue())
      return failure();
    int64_t idx = index.getValue();
    if (AffineMap p = enc.getDimOrdering())
      idx = p.getPermutedPosition(idx);
    // Generate the call.
    StringRef name = "sparseDimSize";
    SmallVector<Value, 2> params;
    params.push_back(adaptor.getOperands()[0]);
    params.push_back(
        rewriter.create<ConstantOp>(op.getLoc(), rewriter.getIndexAttr(idx)));
    rewriter.replaceOpWithNewOp<CallOp>(
        op, resType, getFunc(op, name, resType, params), params);
    return success();
  }
};

/// Sparse conversion rule for the new operator.
class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NewOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    auto enc = getSparseTensorEncoding(resType);
    if (!enc)
      return failure();
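    // Generate the call to materialize the tensor from the external source
    // referenced by the incoming operand (action 0 of genNewCall).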
    Value perm;
    rewriter.replaceOp(
        op, genNewCall(rewriter, op, enc, 0, perm, adaptor.getOperands()[0]));
    return success();
  }
};

/// Sparse conversion rule for the convert operator.
class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ConvertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    auto encDst = getSparseTensorEncoding(resType);
    auto encSrc = getSparseTensorEncoding(op.source().getType());
    auto src = adaptor.getOperands()[0];
    if (encDst && encSrc) {
      // This is a sparse => sparse conversion, which is handled as follows:
      //   t = src->toCOO();         ; src to COO in dst order
      //   dst = newSparseTensor(t)
      // Using the coordinate scheme as an intermediate does not always
      // yield the fastest conversion but avoids the need for a full
      // O(N^2) conversion matrix.
      Value perm;
      Value coo = genNewCall(rewriter, op, encDst, 3, perm, src);
      rewriter.replaceOp(op, genNewCall(rewriter, op, encDst, 1, perm, coo));
      return success();
    }
    if (!encDst || encSrc) {
      // TODO: sparse => dense
      return failure();
    }
    // This is a dense => sparse conversion or a sparse constant in COO =>
    // sparse conversion, which is handled as follows:
    //   t = newSparseCOO()
    //   ...code to fill the COO tensor t...
    //   s = newSparseTensor(t)
    //
    // To fill the COO tensor from a dense tensor:
    //   for i1 in dim1
    //    ..
    //     for ik in dimk
    //       val = a[i1,..,ik]
    //       if val != 0
    //         t->add(val, [i1,..,ik], [p1,..,pk])
    //
    // To fill the COO tensor from a sparse constant in COO format:
    //   for i in range(NNZ)
    //     val = values[i]
    //     [i1,..,ik] = indices[i]
    //     t->add(val, [i1,..,ik], [p1,..,pk])
    //
    // Note that the dense tensor traversal code is actually implemented
    // using MLIR IR to avoid having to expose too many low-level
    // memref traversal details to the runtime support library.
    // Also note that the code below only generates the "new" ops and
    // the loop nest per se, whereas the entire body of the innermost
    // loop is generated by genAddEltCall().
    Location loc = op->getLoc();
    ShapedType shape = resType.cast<ShapedType>();
    Value perm;
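    // Action 2 yields an empty COO scheme that the loop nest built below
    // fills element by element.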
    Value ptr = genNewCall(rewriter, op, encDst, 2, perm);
    Value ind = allocaIndices(rewriter, loc, shape.getRank());
    SmallVector<Value> lo;
    SmallVector<Value> hi;
    SmallVector<Value> st;
    Value zero = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(0));
    Value one = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(1));
    auto indicesValues = genSplitSparseConstant(rewriter, op, src);
    bool isCOOConstant = indicesValues.hasValue();
    Value indices;
    Value values;
    if (isCOOConstant) {
      indices = indicesValues->first;
      values = indicesValues->second;
      lo.push_back(zero);
      hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, values, 0));
      st.push_back(one);
    } else {
      for (unsigned i = 0, rank = shape.getRank(); i < rank; i++) {
        lo.push_back(zero);
        hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, src, i));
        st.push_back(one);
      }
    }
    Type eltType = shape.getElementType();
    unsigned rank = shape.getRank();
    scf::buildLoopNest(rewriter, op.getLoc(), lo, hi, st, {},
                       [&](OpBuilder &builder, Location loc, ValueRange ivs,
                           ValueRange args) -> scf::ValueVector {
                         Value val;
                         if (isCOOConstant)
                           val = genIndexAndValueForSparse(
                               rewriter, op, indices, values, ind, ivs, rank);
                         else
                           val = genIndexAndValueForDense(rewriter, op, src,
                                                          ind, ivs);
                         genAddEltCall(rewriter, op, eltType, ptr, val, ind,
                                       perm);
                         return {};
                       });
    rewriter.replaceOp(op, genNewCall(rewriter, op, encDst, 1, perm, ptr));
    return success();
  }
};

/// Sparse conversion rule for the release operator.
class SparseTensorReleaseConverter : public OpConversionPattern<ReleaseOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReleaseOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    StringRef name = "delSparseTensor";
    TypeRange none;
    rewriter.create<CallOp>(op.getLoc(), none,
                            getFunc(op, name, none, adaptor.getOperands()),
                            adaptor.getOperands());
    rewriter.eraseOp(op);
    return success();
  }
};

/// Sparse conversion rule for pointer accesses.
class SparseTensorToPointersConverter
    : public OpConversionPattern<ToPointersOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToPointersOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type eltType = resType.cast<ShapedType>().getElementType();
    StringRef name;
    if (eltType.isIndex())
      name = "sparsePointers"; // 64-bit, but its own name for unique signature
    else if (eltType.isInteger(64))
      name = "sparsePointers64";
    else if (eltType.isInteger(32))
      name = "sparsePointers32";
    else if (eltType.isInteger(16))
      name = "sparsePointers16";
    else if (eltType.isInteger(8))
      name = "sparsePointers8";
    else
      return failure();
    rewriter.replaceOpWithNewOp<CallOp>(op, resType,
                                        getFunc(op, name, resType,
                                                adaptor.getOperands(),
                                                /*emitCInterface=*/true),
                                        adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for index accesses.
class SparseTensorToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToIndicesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type eltType = resType.cast<ShapedType>().getElementType();
    StringRef name;
    if (eltType.isIndex())
      name = "sparseIndices"; // 64-bit, but its own name for unique signature
    else if (eltType.isInteger(64))
      name = "sparseIndices64";
    else if (eltType.isInteger(32))
      name = "sparseIndices32";
    else if (eltType.isInteger(16))
      name = "sparseIndices16";
    else if (eltType.isInteger(8))
      name = "sparseIndices8";
    else
      return failure();
    rewriter.replaceOpWithNewOp<CallOp>(op, resType,
                                        getFunc(op, name, resType,
                                                adaptor.getOperands(),
                                                /*emitCInterface=*/true),
                                        adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for value accesses.
class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type eltType = resType.cast<ShapedType>().getElementType();
    StringRef name;
    if (eltType.isF64())
      name = "sparseValuesF64";
    else if (eltType.isF32())
      name = "sparseValuesF32";
    else if (eltType.isInteger(64))
      name = "sparseValuesI64";
    else if (eltType.isInteger(32))
      name = "sparseValuesI32";
    else if (eltType.isInteger(16))
      name = "sparseValuesI16";
    else if (eltType.isInteger(8))
      name = "sparseValuesI8";
    else
      return failure();
    rewriter.replaceOpWithNewOp<CallOp>(op, resType,
                                        getFunc(op, name, resType,
                                                adaptor.getOperands(),
                                                /*emitCInterface=*/true),
                                        adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for tensor reconstruction.
class SparseTensorToTensorConverter : public OpConversionPattern<ToTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  // Simply fold the operator into the pointer to the sparse storage scheme.
  LogicalResult
  matchAndRewrite(ToTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Check that all arguments of the tensor reconstruction operator are
    // calls into the support library that query exactly the same opaque
    // pointer.
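    // For example (types omitted), with operands
    //   %ptrs = call @sparsePointers(%p, %c0)
    //   %inds = call @sparseIndices(%p, %c0)
    //   %vals = call @sparseValuesF64(%p)
    // all queries refer to the same %p, so the reconstruction simply folds
    // to the opaque pointer %p.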
    Value ptr;
    for (Value op : adaptor.getOperands()) {
      if (auto call = op.getDefiningOp<CallOp>()) {
        Value arg = call.getOperand(0);
        if (!arg.getType().isa<LLVM::LLVMPointerType>())
          return failure();
        if (!ptr)
          ptr = arg;
        else if (arg != ptr)
          return failure();
      }
    }
    // If a single opaque pointer is found, perform the folding.
    if (!ptr)
      return failure();
    rewriter.replaceOp(op, ptr);
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with conversion rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
                                                  RewritePatternSet &patterns) {
  patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter,
               SparseTensorNewConverter, SparseTensorConvertConverter,
               SparseTensorReleaseConverter, SparseTensorToPointersConverter,
               SparseTensorToIndicesConverter, SparseTensorToValuesConverter,
               SparseTensorToTensorConverter>(typeConverter,
                                              patterns.getContext());
}