//===- SparseTensorConversion.cpp - Sparse tensor primitives conversion ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Convert sparse tensor primitives to calls into a runtime support library.
// Note that this is a current implementation choice to keep the conversion
// simple. In principle, these primitives could also be converted to actual
// elaborate IR code that implements the primitives on the selected sparse
// tensor storage schemes.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

/// Returns internal type encoding for primary storage. Keep these
/// values consistent with the sparse runtime support library.
static unsigned getPrimaryTypeEncoding(Type tp) {
  if (tp.isF64())
    return 1;
  if (tp.isF32())
    return 2;
  if (tp.isInteger(64))
    return 3;
  if (tp.isInteger(32))
    return 4;
  if (tp.isInteger(16))
    return 5;
  if (tp.isInteger(8))
    return 6;
  return 0;
}

/// Returns internal type encoding for overhead storage. Keep these
/// values consistent with the sparse runtime support library.
static unsigned getOverheadTypeEncoding(unsigned width) {
  switch (width) {
  default:
    return 1;
  case 32:
    return 2;
  case 16:
    return 3;
  case 8:
    return 4;
  }
}

/// Returns internal dimension level type encoding. Keep these
/// values consistent with the sparse runtime support library.
static unsigned
getDimLevelTypeEncoding(SparseTensorEncodingAttr::DimLevelType dlt) {
  switch (dlt) {
  case SparseTensorEncodingAttr::DimLevelType::Dense:
    return 0;
  case SparseTensorEncodingAttr::DimLevelType::Compressed:
    return 1;
  case SparseTensorEncodingAttr::DimLevelType::Singleton:
    return 2;
  }
  llvm_unreachable("Unknown SparseTensorEncodingAttr::DimLevelType");
}

/// Returns integers of given width and values as a constant tensor.
/// We cast the static shape into a dynamic shape to ensure that the
/// method signature remains uniform across different tensor dimensions.
static Value getTensor(ConversionPatternRewriter &rewriter, unsigned width,
                       Location loc, ArrayRef<APInt> values) {
  Type etp = rewriter.getIntegerType(width);
  unsigned sz = values.size();
  RankedTensorType tt1 = RankedTensorType::get({sz}, etp);
  RankedTensorType tt2 = RankedTensorType::get({ShapedType::kDynamicSize}, etp);
  auto elts = rewriter.create<arith::ConstantOp>(
      loc, DenseElementsAttr::get(tt1, values));
  return rewriter.create<tensor::CastOp>(loc, tt2, elts);
}
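
// For illustration only: calling getTensor() with width 8 and the two values
// 0 and 1 emits IR roughly of the form (SSA names are made up):
//   %0 = arith.constant dense<[0, 1]> : tensor<2xi8>
//   %1 = tensor.cast %0 : tensor<2xi8> to tensor<?xi8>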

/// Returns a function reference (first hit also inserts into module). Sets
/// the "llvm.emit_c_interface" attribute on the function declaration when
/// requested, so that LLVM lowering generates a wrapper function that takes
/// care of ABI complications with passing in and returning MemRefs to C
/// functions.
static FlatSymbolRefAttr getFunc(Operation *op, StringRef name,
                                 TypeRange resultType, ValueRange operands,
                                 bool emitCInterface = false) {
  MLIRContext *context = op->getContext();
  auto module = op->getParentOfType<ModuleOp>();
  auto result = SymbolRefAttr::get(context, name);
  auto func = module.lookupSymbol<FuncOp>(result.getAttr());
  if (!func) {
    OpBuilder moduleBuilder(module.getBodyRegion());
    func = moduleBuilder.create<FuncOp>(
        op->getLoc(), name,
        FunctionType::get(context, operands.getTypes(), resultType));
    func.setPrivate();
    if (emitCInterface)
      func->setAttr("llvm.emit_c_interface", UnitAttr::get(context));
  }
  return result;
}
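
// For illustration only: the first lookup of "newSparseTensor" inserts a
// private declaration at module level, printed roughly as (exact operand
// types depend on the call site that triggered the insertion):
//   func private @newSparseTensor(tensor<?xi8>, tensor<?xi64>, tensor<?xi64>,
//                                 i64, i64, i64, i32, !llvm.ptr<i8>)
//       -> !llvm.ptr<i8> attributes {llvm.emit_c_interface}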

/// Generates a call into the "swiss army knife" method of the sparse runtime
/// support library for materializing sparse tensors into the computation. The
/// method returns the call value and assigns the permutation to 'perm'. The
/// 'action' code tells the runtime what to do and must be kept consistent
/// with the runtime support library; within this file, 0 materializes from
/// the source referenced by 'ptr' (e.g., a file), 1 builds the final sparse
/// storage from a COO scheme in 'ptr', 2 creates an empty COO scheme to be
/// filled, and 3 converts an existing sparse tensor in 'ptr' to COO.
static Value genNewCall(ConversionPatternRewriter &rewriter, Operation *op,
                        SparseTensorEncodingAttr &enc, uint32_t action,
                        Value &perm, Value ptr = Value()) {
  Location loc = op->getLoc();
  ShapedType resType = op->getResult(0).getType().cast<ShapedType>();
  SmallVector<Value, 8> params;
  // Sparsity annotations in tensor constant form.
  SmallVector<APInt, 4> attrs;
  unsigned sz = enc.getDimLevelType().size();
  for (unsigned i = 0; i < sz; i++)
    attrs.push_back(
        APInt(8, getDimLevelTypeEncoding(enc.getDimLevelType()[i])));
  params.push_back(getTensor(rewriter, 8, loc, attrs));
  // Dimension sizes array of the enveloping *dense* tensor. Useful for either
  // verification of external data, or for construction of internal data.
  auto shape = resType.getShape();
  SmallVector<APInt, 4> sizes;
  for (unsigned i = 0; i < sz; i++) {
    uint64_t s = shape[i] == ShapedType::kDynamicSize ? 0 : shape[i];
    sizes.push_back(APInt(64, s));
  }
  params.push_back(getTensor(rewriter, 64, loc, sizes));
  // Dimension order permutation array. This is the "identity" permutation by
  // default, or otherwise the "reverse" permutation of a given ordering, so
  // that indices can be mapped quickly to the right position.
  SmallVector<APInt, 4> rev(sz);
  if (AffineMap p = enc.getDimOrdering()) {
    for (unsigned i = 0; i < sz; i++)
      rev[p.getDimPosition(i)] = APInt(64, i);
  } else {
    for (unsigned i = 0; i < sz; i++)
      rev[i] = APInt(64, i);
  }
  perm = getTensor(rewriter, 64, loc, rev);
  params.push_back(perm);
  // Secondary and primary types encoding.
  unsigned secPtr = getOverheadTypeEncoding(enc.getPointerBitWidth());
  unsigned secInd = getOverheadTypeEncoding(enc.getIndexBitWidth());
  unsigned primary = getPrimaryTypeEncoding(resType.getElementType());
  assert(primary);
  params.push_back(rewriter.create<arith::ConstantOp>(
      loc, rewriter.getI64IntegerAttr(secPtr)));
  params.push_back(rewriter.create<arith::ConstantOp>(
      loc, rewriter.getI64IntegerAttr(secInd)));
  params.push_back(rewriter.create<arith::ConstantOp>(
      loc, rewriter.getI64IntegerAttr(primary)));
  // User action and pointer.
  Type pTp = LLVM::LLVMPointerType::get(IntegerType::get(op->getContext(), 8));
  if (!ptr)
    ptr = rewriter.create<LLVM::NullOp>(loc, pTp);
  params.push_back(rewriter.create<arith::ConstantOp>(
      loc, rewriter.getI32IntegerAttr(action)));
  params.push_back(ptr);
  // Generate the call to create new tensor.
  StringRef name = "newSparseTensor";
  auto call = rewriter.create<CallOp>(
      loc, pTp, getFunc(op, name, pTp, params, /*emitCInterface=*/true),
      params);
  return call.getResult(0);
}
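
// For illustration only: when materializing a statically shaped 64x64 CSR
// matrix of f64 with default bit widths from an external source %src (action
// code 0), the generated IR looks roughly like (SSA names are made up):
//   %attrs = ... constant dense<[0, 1]> : tensor<2xi8>, cast to tensor<?xi8>
//   %sizes = ... constant dense<[64, 64]> : tensor<2xi64>, cast to tensor<?xi64>
//   %perm  = ... constant dense<[0, 1]> : tensor<2xi64>, cast to tensor<?xi64>
//   %t = call @newSparseTensor(%attrs, %sizes, %perm, %c1_i64, %c1_i64,
//                              %c1_i64, %c0_i32, %src)
//       : (tensor<?xi8>, tensor<?xi64>, tensor<?xi64>, i64, i64, i64, i32,
//          !llvm.ptr<i8>) -> !llvm.ptr<i8>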

/// Generates a constant zero of the given type.
static Value getZero(ConversionPatternRewriter &rewriter, Location loc,
                     Type t) {
  return rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(t));
}

/// Generates the comparison `v != 0` where `v` is of numeric type `t`.
/// For floating types, we use the "unordered" comparator (i.e., returns
/// true if `v` is NaN).
static Value genIsNonzero(ConversionPatternRewriter &rewriter, Location loc,
                          Value v) {
  Type t = v.getType();
  Value zero = getZero(rewriter, loc, t);
  if (t.isa<FloatType>())
    return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE, v,
                                          zero);
  if (t.isIntOrIndex())
    return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, v,
                                          zero);
  llvm_unreachable("Unknown element type");
}
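
// For illustration only: for an f64 value %v this emits roughly
//   %zero = arith.constant 0.000000e+00 : f64
//   %cond = arith.cmpf une, %v, %zero : f64
// so that a NaN entry also counts as nonzero.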

/// Generates the code to read the value from tensor[ivs], and conditionally
/// stores the indices ivs to the memory in ind. The generated code looks like
/// the following and the insertion point after this routine is inside the
/// if-then branch behind the assignment to ind. This is to ensure that the
/// addEltX call generated after is inside the if-then branch.
///    if (tensor[ivs]!=0) {
///      ind = ivs
static Value genIndexAndValueForDense(ConversionPatternRewriter &rewriter,
                                      Operation *op, Value tensor, Value ind,
                                      ValueRange ivs) {
  Location loc = op->getLoc();
  Value val = rewriter.create<tensor::ExtractOp>(loc, tensor, ivs);
  Value cond = genIsNonzero(rewriter, loc, val);
  scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ false);
  rewriter.setInsertionPointToStart(&ifOp.thenRegion().front());
  unsigned i = 0;
  for (auto iv : ivs) {
    Value idx =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(i++));
    rewriter.create<memref::StoreOp>(loc, iv, ind, idx);
  }
  return val;
}
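
// For illustration only: for a 2-d dense source this generates roughly
//   %v = tensor.extract %src[%i0, %i1] : tensor<?x?xf64>
//   %cond = arith.cmpf une, %v, %zero : f64
//   scf.if %cond {
//     memref.store %i0, %ind[%c0] : memref<?xindex>
//     memref.store %i1, %ind[%c1] : memref<?xindex>
//     // the caller emits the addEltX call here
//   }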

/// Generates a call that adds one element to a coordinate scheme.
/// In particular, this generates code like the following:
///   val = a[i1,..,ik];
///   if val != 0
///     t->add(val, [i1,..,ik], [p1,..,pk]);
static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
                          Type eltType, Value ptr, Value val, Value ind,
                          Value perm) {
  Location loc = op->getLoc();
  StringRef name;
  if (eltType.isF64())
    name = "addEltF64";
  else if (eltType.isF32())
    name = "addEltF32";
  else if (eltType.isInteger(64))
    name = "addEltI64";
  else if (eltType.isInteger(32))
    name = "addEltI32";
  else if (eltType.isInteger(16))
    name = "addEltI16";
  else if (eltType.isInteger(8))
    name = "addEltI8";
  else
    llvm_unreachable("Unknown element type");
  SmallVector<Value, 8> params;
  params.push_back(ptr);
  params.push_back(val);
  params.push_back(ind);
  params.push_back(perm);
  Type pTp = LLVM::LLVMPointerType::get(IntegerType::get(op->getContext(), 8));
  rewriter.create<CallOp>(
      loc, pTp, getFunc(op, name, pTp, params, /*emitCInterface=*/true),
      params);
}
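
// For illustration only: for f64 elements the generated call looks roughly
// like
//   %0 = call @addEltF64(%coo, %val, %ind, %perm)
//       : (!llvm.ptr<i8>, f64, memref<?xindex>, tensor<?xi64>) -> !llvm.ptr<i8>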

/// If the tensor is a sparse constant, generates and returns the pair of
/// the constants for the indices and the values.
static Optional<std::pair<Value, Value>>
genSplitSparseConstant(ConversionPatternRewriter &rewriter, ConvertOp op,
                       Value tensor) {
  if (auto constOp = tensor.getDefiningOp<arith::ConstantOp>()) {
    if (auto attr = constOp.value().dyn_cast<SparseElementsAttr>()) {
      Location loc = op->getLoc();
      DenseElementsAttr indicesAttr = attr.getIndices();
      Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr);
      DenseElementsAttr valuesAttr = attr.getValues();
      Value values = rewriter.create<arith::ConstantOp>(loc, valuesAttr);
      return std::make_pair(indices, values);
    }
  }
  return {};
}
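
// For illustration only: a sparse constant such as
//   arith.constant sparse<[[0, 0], [1, 2]], [1.0, 5.0]> : tensor<8x8xf64>
// is split into an indices constant of type tensor<2x2xi64> and a values
// constant of type tensor<2xf64>.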

/// Generates the code to copy the index at indices[ivs] to ind, and returns
/// the value at values[ivs].
static Value genIndexAndValueForSparse(ConversionPatternRewriter &rewriter,
                                       Operation *op, Value indices,
                                       Value values, Value ind, ValueRange ivs,
                                       unsigned rank) {
  Location loc = op->getLoc();
  for (unsigned i = 0; i < rank; i++) {
    Value idx =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(i));
    Value val = rewriter.create<tensor::ExtractOp>(loc, indices,
                                                   ValueRange{ivs[0], idx});
    val =
        rewriter.create<arith::IndexCastOp>(loc, val, rewriter.getIndexType());
    rewriter.create<memref::StoreOp>(loc, val, ind, idx);
  }
  return rewriter.create<tensor::ExtractOp>(loc, values, ivs[0]);
}

/// Generates code to stack-allocate a `memref<?xindex>` where the `?`
/// is the given `rank`.  This array is intended to serve as a reusable
/// buffer for storing the indices of a single tensor element, to avoid
/// allocation in the body of loops.
static Value allocaIndices(ConversionPatternRewriter &rewriter, Location loc,
                           int64_t rank) {
  auto indexTp = rewriter.getIndexType();
  auto memTp = MemRefType::get({ShapedType::kDynamicSize}, indexTp);
  Value arg =
      rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(rank));
  return rewriter.create<memref::AllocaOp>(loc, memTp, ValueRange{arg});
}
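
// For illustration only: allocaIndices(rewriter, loc, 3) emits roughly
//   %rank = arith.constant 3 : index
//   %ind = memref.alloca(%rank) : memref<?xindex>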

//===----------------------------------------------------------------------===//
// Conversion rules.
//===----------------------------------------------------------------------===//

/// Sparse conversion rule for returns.
class SparseReturnConverter : public OpConversionPattern<ReturnOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<ReturnOp>(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for dimension accesses.
class SparseTensorToDimSizeConverter
    : public OpConversionPattern<tensor::DimOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    auto enc = getSparseTensorEncoding(op.source().getType());
    if (!enc)
      return failure();
    // Permute the dim index.
    Optional<int64_t> index = op.getConstantIndex();
    if (!index.hasValue())
      return failure();
    int64_t idx = index.getValue();
    if (AffineMap p = enc.getDimOrdering())
      idx = p.getPermutedPosition(idx);
    // Generate the call.
    StringRef name = "sparseDimSize";
    SmallVector<Value, 2> params;
    params.push_back(adaptor.getOperands()[0]);
    params.push_back(rewriter.create<arith::ConstantOp>(
        op.getLoc(), rewriter.getIndexAttr(idx)));
    rewriter.replaceOpWithNewOp<CallOp>(
        op, resType, getFunc(op, name, resType, params), params);
    return success();
  }
};
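
// For illustration only: a dimension query on a sparse tensor, such as
//   %d = tensor.dim %t, %c1 : tensor<?x?xf64, #CSR>
// is rewritten roughly into (with %ptr the converted opaque pointer and the
// constant index rebuilt after permutation):
//   %c1 = arith.constant 1 : index
//   %d = call @sparseDimSize(%ptr, %c1) : (!llvm.ptr<i8>, index) -> index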

/// Sparse conversion rule for the new operator.
class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NewOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    auto enc = getSparseTensorEncoding(resType);
    if (!enc)
      return failure();
    Value perm;
    rewriter.replaceOp(
        op, genNewCall(rewriter, op, enc, 0, perm, adaptor.getOperands()[0]));
    return success();
  }
};

/// Sparse conversion rule for the convert operator.
class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ConvertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    auto encDst = getSparseTensorEncoding(resType);
    auto encSrc = getSparseTensorEncoding(op.source().getType());
    auto src = adaptor.getOperands()[0];
    if (encDst && encSrc) {
      // This is a sparse => sparse conversion, which is handled as follows:
      //   t = src->toCOO();         ; src to COO in dst order
      //   dst = newSparseTensor(t)
      // Using the coordinate scheme as an intermediate does not always
      // yield the fastest conversion but avoids the need for a full
      // O(N^2) conversion matrix.
      Value perm;
      Value coo = genNewCall(rewriter, op, encDst, 3, perm, src);
      rewriter.replaceOp(op, genNewCall(rewriter, op, encDst, 1, perm, coo));
      return success();
    }
    if (!encDst || encSrc) {
      // TODO: sparse => dense
      return failure();
    }
    // This is a dense => sparse conversion or a sparse constant in COO =>
    // sparse conversion, which is handled as follows:
    //   t = newSparseCOO()
    //   ...code to fill the COO tensor t...
    //   s = newSparseTensor(t)
    //
    // To fill the COO tensor from a dense tensor:
    //   for i1 in dim1
    //    ..
    //     for ik in dimk
    //       val = a[i1,..,ik]
    //       if val != 0
    //         t->add(val, [i1,..,ik], [p1,..,pk])
    //
    // To fill the COO tensor from a sparse constant in COO format:
    //   for i in range(NNZ)
    //     val = values[i]
    //     [i1,..,ik] = indices[i]
    //     t->add(val, [i1,..,ik], [p1,..,pk])
    //
    // Note that the dense tensor traversal code is actually implemented
    // using MLIR IR to avoid having to expose too much low-level
    // memref traversal details to the runtime support library.
    // Also note that the code below only generates the "new" ops and
    // the loop-nest per se; the entire body of the innermost loop is
    // generated by genIndexAndValueForDense()/genIndexAndValueForSparse()
    // and genAddEltCall().
    Location loc = op->getLoc();
    ShapedType shape = resType.cast<ShapedType>();
    Value perm;
    Value ptr = genNewCall(rewriter, op, encDst, 2, perm);
    Value ind = allocaIndices(rewriter, loc, shape.getRank());
    SmallVector<Value> lo;
    SmallVector<Value> hi;
    SmallVector<Value> st;
    Value zero =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
    Value one =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
    auto indicesValues = genSplitSparseConstant(rewriter, op, src);
    bool isCOOConstant = indicesValues.hasValue();
    Value indices;
    Value values;
    if (isCOOConstant) {
      indices = indicesValues->first;
      values = indicesValues->second;
      lo.push_back(zero);
      hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, values, 0));
      st.push_back(one);
    } else {
      for (unsigned i = 0, rank = shape.getRank(); i < rank; i++) {
        lo.push_back(zero);
        hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, src, i));
        st.push_back(one);
      }
    }
    Type eltType = shape.getElementType();
    unsigned rank = shape.getRank();
    scf::buildLoopNest(
        rewriter, op.getLoc(), lo, hi, st, {},
        [&](OpBuilder &builder, Location loc, ValueRange ivs,
            ValueRange args) -> scf::ValueVector {
          Value val;
          if (isCOOConstant)
            val = genIndexAndValueForSparse(rewriter, op, indices, values, ind,
                                            ivs, rank);
          else
            val = genIndexAndValueForDense(rewriter, op, src, ind, ivs);
          genAddEltCall(rewriter, op, eltType, ptr, val, ind, perm);
          return {};
        });
    rewriter.replaceOp(op, genNewCall(rewriter, op, encDst, 1, perm, ptr));
    return success();
  }
};

/// Sparse conversion rule for the release operator.
class SparseTensorReleaseConverter : public OpConversionPattern<ReleaseOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReleaseOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    StringRef name = "delSparseTensor";
    TypeRange none;
    rewriter.create<CallOp>(op.getLoc(), none,
                            getFunc(op, name, none, adaptor.getOperands()),
                            adaptor.getOperands());
    rewriter.eraseOp(op);
    return success();
  }
};

/// Sparse conversion rule for pointer accesses.
class SparseTensorToPointersConverter
    : public OpConversionPattern<ToPointersOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToPointersOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type eltType = resType.cast<ShapedType>().getElementType();
    StringRef name;
    if (eltType.isIndex())
      name = "sparsePointers"; // 64-bit, but its own name for unique signature
    else if (eltType.isInteger(64))
      name = "sparsePointers64";
    else if (eltType.isInteger(32))
      name = "sparsePointers32";
    else if (eltType.isInteger(16))
      name = "sparsePointers16";
    else if (eltType.isInteger(8))
      name = "sparsePointers8";
    else
      return failure();
    rewriter.replaceOpWithNewOp<CallOp>(op, resType,
                                        getFunc(op, name, resType,
                                                adaptor.getOperands(),
                                                /*emitCInterface=*/true),
                                        adaptor.getOperands());
    return success();
  }
};
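
// For illustration only: the pointers/indices/values accessors below all
// follow the same pattern; e.g.
//   %p = sparse_tensor.pointers %t, %c1
//       : tensor<?x?xf64, #CSR> to memref<?xindex>
// becomes roughly
//   %p = call @sparsePointers(%ptr, %c1)
//       : (!llvm.ptr<i8>, index) -> memref<?xindex>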

/// Sparse conversion rule for index accesses.
class SparseTensorToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToIndicesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type eltType = resType.cast<ShapedType>().getElementType();
    StringRef name;
    if (eltType.isIndex())
      name = "sparseIndices"; // 64-bit, but its own name for unique signature
    else if (eltType.isInteger(64))
      name = "sparseIndices64";
    else if (eltType.isInteger(32))
      name = "sparseIndices32";
    else if (eltType.isInteger(16))
      name = "sparseIndices16";
    else if (eltType.isInteger(8))
      name = "sparseIndices8";
    else
      return failure();
    rewriter.replaceOpWithNewOp<CallOp>(op, resType,
                                        getFunc(op, name, resType,
                                                adaptor.getOperands(),
                                                /*emitCInterface=*/true),
                                        adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for value accesses.
class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type eltType = resType.cast<ShapedType>().getElementType();
    StringRef name;
    if (eltType.isF64())
      name = "sparseValuesF64";
    else if (eltType.isF32())
      name = "sparseValuesF32";
    else if (eltType.isInteger(64))
      name = "sparseValuesI64";
    else if (eltType.isInteger(32))
      name = "sparseValuesI32";
    else if (eltType.isInteger(16))
      name = "sparseValuesI16";
    else if (eltType.isInteger(8))
      name = "sparseValuesI8";
    else
      return failure();
    rewriter.replaceOpWithNewOp<CallOp>(op, resType,
                                        getFunc(op, name, resType,
                                                adaptor.getOperands(),
                                                /*emitCInterface=*/true),
                                        adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for tensor reconstruction.
class SparseTensorToTensorConverter : public OpConversionPattern<ToTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  // Simply fold the operator into the pointer to the sparse storage scheme.
  LogicalResult
  matchAndRewrite(ToTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Check that all arguments of the tensor reconstruction operator are
    // calls into the support library that query exactly the same opaque
    // pointer.
    Value ptr;
    for (Value op : adaptor.getOperands()) {
      if (auto call = op.getDefiningOp<CallOp>()) {
        Value arg = call.getOperand(0);
        if (!arg.getType().isa<LLVM::LLVMPointerType>())
          return failure();
        if (!ptr)
          ptr = arg;
        else if (arg != ptr)
          return failure();
      }
    }
    // If a single opaque pointer is found, perform the folding.
    if (!ptr)
      return failure();
    rewriter.replaceOp(op, ptr);
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with conversion rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
                                                  RewritePatternSet &patterns) {
  patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter,
               SparseTensorNewConverter, SparseTensorConvertConverter,
               SparseTensorReleaseConverter, SparseTensorToPointersConverter,
               SparseTensorToIndicesConverter, SparseTensorToValuesConverter,
               SparseTensorToTensorConverter>(typeConverter,
                                              patterns.getContext());
}