//===- SparseTensorConversion.cpp - Sparse tensor primitives conversion ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Convert sparse tensor primitives to calls into a runtime support library.
// Note that calling into a runtime library is a deliberate implementation
// choice that keeps this conversion simple. In principle, these primitives
// could also be converted to elaborate IR that implements the primitives
// directly on the selected sparse tensor storage schemes.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

/// Returns internal type encoding for primary storage. Keep these
/// values consistent with the sparse runtime support library.
static unsigned getPrimaryTypeEncoding(Type tp) {
  if (tp.isF64())
    return 1;
  if (tp.isF32())
    return 2;
  if (tp.isInteger(64))
    return 3;
  if (tp.isInteger(32))
    return 4;
  if (tp.isInteger(16))
    return 5;
  if (tp.isInteger(8))
    return 6;
  return 0; // unsupported element type
}

/// Returns internal type encoding for overhead storage. Keep these
/// values consistent with the sparse runtime support library.
static unsigned getOverheadTypeEncoding(unsigned width) {
  switch (width) {
  default:
    return 1; // 64-bit (also the default)
  case 32:
    return 2;
  case 16:
    return 3;
  case 8:
    return 4;
  }
}

/// Returns internal dimension level type encoding. Keep these
/// values consistent with the sparse runtime support library.
static unsigned
getDimLevelTypeEncoding(SparseTensorEncodingAttr::DimLevelType dlt) {
  switch (dlt) {
  case SparseTensorEncodingAttr::DimLevelType::Dense:
    return 0;
  case SparseTensorEncodingAttr::DimLevelType::Compressed:
    return 1;
  case SparseTensorEncodingAttr::DimLevelType::Singleton:
    return 2;
  }
  llvm_unreachable("Unknown SparseTensorEncodingAttr::DimLevelType");
}

/// Generates a constant zero of the given type.
inline static Value constantZero(ConversionPatternRewriter &rewriter,
                                 Location loc, Type t) {
  return rewriter.create<arith::ConstantOp>(loc, t, rewriter.getZeroAttr(t));
}

/// Generates a constant of `index` type.
inline static Value constantIndex(ConversionPatternRewriter &rewriter,
                                  Location loc, unsigned i) {
  return rewriter.create<arith::ConstantIndexOp>(loc, i);
}

/// Generates a constant of `i64` type.
inline static Value constantI64(ConversionPatternRewriter &rewriter,
                                Location loc, int64_t i) {
  return rewriter.create<arith::ConstantIntOp>(loc, i, 64);
}

/// Generates a constant of `i32` type.
inline static Value constantI32(ConversionPatternRewriter &rewriter,
                                Location loc, int32_t i) {
  return rewriter.create<arith::ConstantIntOp>(loc, i, 32);
}

/// Returns integers of given width and values as a constant tensor.
/// We cast the static shape into a dynamic shape to ensure that the
/// runtime call signature remains uniform across tensors of different sizes.
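/// For example, for width 64 and values [0, 1], the generated IR is roughly
/// (with illustrative SSA names):
///   %cst = arith.constant dense<[0, 1]> : tensor<2xi64>
///   %0 = tensor.cast %cst : tensor<2xi64> to tensor<?xi64>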
static Value getTensor(ConversionPatternRewriter &rewriter, unsigned width,
                       Location loc, ArrayRef<APInt> values) {
  Type etp = rewriter.getIntegerType(width);
  unsigned sz = values.size();
  RankedTensorType tt1 = RankedTensorType::get({sz}, etp);
  RankedTensorType tt2 = RankedTensorType::get({ShapedType::kDynamicSize}, etp);
  auto elts = rewriter.create<arith::ConstantOp>(
      loc, DenseElementsAttr::get(tt1, values));
  return rewriter.create<tensor::CastOp>(loc, tt2, elts);
}

/// Returns a function reference (first hit also inserts into module). Sets
/// the "llvm.emit_c_interface" attribute on the function declaration when
/// requested, so that LLVM lowering generates a wrapper function that takes
/// care of ABI complications when passing and returning MemRefs to C functions.
static FlatSymbolRefAttr getFunc(Operation *op, StringRef name,
                                 TypeRange resultType, ValueRange operands,
                                 bool emitCInterface = false) {
  MLIRContext *context = op->getContext();
  auto module = op->getParentOfType<ModuleOp>();
  auto result = SymbolRefAttr::get(context, name);
  auto func = module.lookupSymbol<FuncOp>(result.getAttr());
  if (!func) {
    OpBuilder moduleBuilder(module.getBodyRegion());
    func = moduleBuilder.create<FuncOp>(
        op->getLoc(), name,
        FunctionType::get(context, operands.getTypes(), resultType));
    func.setPrivate();
    if (emitCInterface)
      func->setAttr("llvm.emit_c_interface", UnitAttr::get(context));
  }
  return result;
}

/// Generates a call into the "swiss army knife" method of the sparse runtime
/// support library for materializing sparse tensors into the computation. The
/// method returns the call value and assigns the permutation to 'perm'. The
/// 'action' values must be kept consistent with the runtime support library:
/// 0 constructs a tensor from the source 'ptr' (used by the new operator to
/// read a tensor from file), 1 constructs a tensor from a coordinate scheme
/// 'ptr', 2 creates an empty coordinate scheme, and 3 converts the sparse
/// tensor 'ptr' into a coordinate scheme.
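/// The generated call has roughly the following form (SSA names are
/// illustrative only):
///   %tensor = call @newSparseTensor(%attrs, %sizes, %perm, %secPtr,
///                                   %secInd, %primary, %action, %ptr)
///       : (tensor<?xi8>, tensor<?xi64>, tensor<?xi64>, i64, i64, i64,
///          i32, !llvm.ptr<i8>) -> !llvm.ptr<i8>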
static Value genNewCall(ConversionPatternRewriter &rewriter, Operation *op,
                        SparseTensorEncodingAttr &enc, uint32_t action,
                        Value &perm, Value ptr = Value()) {
  Location loc = op->getLoc();
  ShapedType resType = op->getResult(0).getType().cast<ShapedType>();
  SmallVector<Value, 8> params;
  // Sparsity annotations in tensor constant form.
  SmallVector<APInt, 4> attrs;
  unsigned sz = enc.getDimLevelType().size();
  for (unsigned i = 0; i < sz; i++)
    attrs.push_back(
        APInt(8, getDimLevelTypeEncoding(enc.getDimLevelType()[i])));
  params.push_back(getTensor(rewriter, 8, loc, attrs));
  // Dimension sizes array of the enveloping *dense* tensor. Useful for either
  // verification of external data, or for construction of internal data.
  auto shape = resType.getShape();
  SmallVector<APInt, 4> sizes;
  for (unsigned i = 0; i < sz; i++) {
    uint64_t s = shape[i] == ShapedType::kDynamicSize ? 0 : shape[i];
    sizes.push_back(APInt(64, s));
  }
  params.push_back(getTensor(rewriter, 64, loc, sizes));
  // Dimension order permutation array. This is the "identity" permutation by
  // default, or otherwise the inverse permutation of a given ordering, so
  // that indices can be mapped quickly to the right position.
  SmallVector<APInt, 4> rev(sz);
  if (AffineMap p = enc.getDimOrdering()) {
    for (unsigned i = 0; i < sz; i++)
      rev[p.getDimPosition(i)] = APInt(64, i);
  } else {
    for (unsigned i = 0; i < sz; i++)
      rev[i] = APInt(64, i);
  }
  perm = getTensor(rewriter, 64, loc, rev);
  params.push_back(perm);
  // Secondary and primary type encodings.
  unsigned secPtr = getOverheadTypeEncoding(enc.getPointerBitWidth());
  unsigned secInd = getOverheadTypeEncoding(enc.getIndexBitWidth());
  unsigned primary = getPrimaryTypeEncoding(resType.getElementType());
  assert(primary);
  params.push_back(constantI64(rewriter, loc, secPtr));
  params.push_back(constantI64(rewriter, loc, secInd));
  params.push_back(constantI64(rewriter, loc, primary));
  // User action and pointer.
  Type pTp = LLVM::LLVMPointerType::get(rewriter.getI8Type());
  if (!ptr)
    ptr = rewriter.create<LLVM::NullOp>(loc, pTp);
  params.push_back(constantI32(rewriter, loc, action));
  params.push_back(ptr);
  // Generate the call to create new tensor.
  StringRef name = "newSparseTensor";
  auto call = rewriter.create<CallOp>(
      loc, pTp, getFunc(op, name, pTp, params, /*emitCInterface=*/true),
      params);
  return call.getResult(0);
}

/// Generates the comparison `v != 0` where `v` is of numeric type `t`.
/// For floating types, we use the "unordered" comparator, so the comparison
/// also returns true if `v` is NaN.
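/// For example, for an f64 value %v, the generated IR is roughly:
///   %zero = arith.constant 0.0 : f64
///   %cond = arith.cmpf une, %v, %zero : f64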
static Value genIsNonzero(ConversionPatternRewriter &rewriter, Location loc,
                          Value v) {
  Type t = v.getType();
  Value zero = constantZero(rewriter, loc, t);
  if (t.isa<FloatType>())
    return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE, v,
                                          zero);
  if (t.isIntOrIndex())
    return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, v,
                                          zero);
  llvm_unreachable("Unknown element type");
}

/// Generates the code to read the value from tensor[ivs], and conditionally
/// stores the indices ivs to the memory in ind. The generated code looks like
/// the following, and the insertion point after this routine is inside the
/// if-then branch behind the assignment to ind, which ensures that any
/// addEltX call generated after this routine ends up inside that branch:
///    if (tensor[ivs] != 0) {
///      ind = ivs
static Value genIndexAndValueForDense(ConversionPatternRewriter &rewriter,
                                      Location loc, Value tensor, Value ind,
                                      ValueRange ivs) {
  Value val = rewriter.create<tensor::ExtractOp>(loc, tensor, ivs);
  Value cond = genIsNonzero(rewriter, loc, val);
  scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ false);
  rewriter.setInsertionPointToStart(&ifOp.thenRegion().front());
  unsigned i = 0;
  for (auto iv : ivs) {
    Value idx = constantIndex(rewriter, loc, i++);
    rewriter.create<memref::StoreOp>(loc, iv, ind, idx);
  }
  return val;
}

/// Generates a call that adds one element to a coordinate scheme. Together
/// with the guard emitted by genIndexAndValueForDense(), the generated code
/// effectively performs:
///   val = a[i1,..,ik];
///   if val != 0
///     t->add(val, [i1,..,ik], [p1,..,pk]);
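/// For an f64 element, for example, the generated call is roughly (SSA
/// names are illustrative only):
///   %t = call @addEltF64(%ptr, %val, %ind, %perm)
///       : (!llvm.ptr<i8>, f64, memref<?xindex>, tensor<?xi64>)
///       -> !llvm.ptr<i8>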
static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
                          Type eltType, Value ptr, Value val, Value ind,
                          Value perm) {
  Location loc = op->getLoc();
  StringRef name;
  if (eltType.isF64())
    name = "addEltF64";
  else if (eltType.isF32())
    name = "addEltF32";
  else if (eltType.isInteger(64))
    name = "addEltI64";
  else if (eltType.isInteger(32))
    name = "addEltI32";
  else if (eltType.isInteger(16))
    name = "addEltI16";
  else if (eltType.isInteger(8))
    name = "addEltI8";
  else
    llvm_unreachable("Unknown element type");
  SmallVector<Value, 8> params;
  params.push_back(ptr);
  params.push_back(val);
  params.push_back(ind);
  params.push_back(perm);
  Type pTp = LLVM::LLVMPointerType::get(rewriter.getI8Type());
  rewriter.create<CallOp>(
      loc, pTp, getFunc(op, name, pTp, params, /*emitCInterface=*/true),
      params);
}

/// If the tensor is a sparse constant, generates and returns a pair of
/// constants, one for the indices and one for the values.
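/// For example, the following constant is split into a 2x2 index constant
/// and a 2-element value constant:
///   arith.constant sparse<[[0, 0], [1, 2]], [1.0, 2.0]> : tensor<2x3xf64>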
static Optional<std::pair<Value, Value>>
genSplitSparseConstant(ConversionPatternRewriter &rewriter, Location loc,
                       Value tensor) {
  if (auto constOp = tensor.getDefiningOp<arith::ConstantOp>()) {
    if (auto attr = constOp.value().dyn_cast<SparseElementsAttr>()) {
      DenseElementsAttr indicesAttr = attr.getIndices();
      Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr);
      DenseElementsAttr valuesAttr = attr.getValues();
      Value values = rewriter.create<arith::ConstantOp>(loc, valuesAttr);
      return std::make_pair(indices, values);
    }
  }
  return {};
}

/// Generates the code to copy the index at indices[ivs] to ind, and returns
/// the value at values[ivs].
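/// In pseudocode, with ivs[0] being the position within the COO scheme:
///   for i = 0 to rank - 1
///     ind[i] = indices[ivs[0], i]
///   return values[ivs[0]]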
static Value genIndexAndValueForSparse(ConversionPatternRewriter &rewriter,
                                       Location loc, Value indices,
                                       Value values, Value ind, ValueRange ivs,
                                       unsigned rank) {
  for (unsigned i = 0; i < rank; i++) {
    Value idx = constantIndex(rewriter, loc, i);
    Value val = rewriter.create<tensor::ExtractOp>(loc, indices,
                                                   ValueRange{ivs[0], idx});
    val =
        rewriter.create<arith::IndexCastOp>(loc, val, rewriter.getIndexType());
    rewriter.create<memref::StoreOp>(loc, val, ind, idx);
  }
  return rewriter.create<tensor::ExtractOp>(loc, values, ivs[0]);
}

/// Generates code to stack-allocate a `memref<?xindex>` where the `?`
/// is the given `rank`. This array is intended to serve as a reusable
/// buffer for storing the indices of a single tensor element, to avoid
/// allocation in the body of loops.
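/// For example, for rank 2 the generated IR is roughly:
///   %c2 = arith.constant 2 : index
///   %0 = memref.alloca(%c2) : memref<?xindex>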
static Value allocaIndices(ConversionPatternRewriter &rewriter, Location loc,
                           int64_t rank) {
  auto indexTp = rewriter.getIndexType();
  auto memTp = MemRefType::get({ShapedType::kDynamicSize}, indexTp);
  Value arg = constantIndex(rewriter, loc, rank);
  return rewriter.create<memref::AllocaOp>(loc, memTp, ValueRange{arg});
}

//===----------------------------------------------------------------------===//
// Conversion rules.
//===----------------------------------------------------------------------===//

/// Sparse conversion rule for returns.
class SparseReturnConverter : public OpConversionPattern<ReturnOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<ReturnOp>(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for dimension accesses.
class SparseTensorToDimSizeConverter
    : public OpConversionPattern<tensor::DimOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    auto enc = getSparseTensorEncoding(op.source().getType());
    if (!enc)
      return failure();
    // Permute the dim index.
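    // For example, with a dimension ordering (i, j) -> (j, i), the size of
    // dim 0 of the original tensor is found at position 1 of the storage
    // scheme, which is the position the runtime call below expects.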
    Optional<int64_t> index = op.getConstantIndex();
    if (!index.hasValue())
      return failure();
    int64_t idx = index.getValue();
    if (AffineMap p = enc.getDimOrdering())
      idx = p.getPermutedPosition(idx);
    // Generate the call.
    StringRef name = "sparseDimSize";
    SmallVector<Value, 2> params;
    params.push_back(adaptor.getOperands()[0]);
    params.push_back(constantIndex(rewriter, op.getLoc(), idx));
    rewriter.replaceOpWithNewOp<CallOp>(
        op, resType, getFunc(op, name, resType, params), params);
    return success();
  }
};

/// Sparse conversion rule for the new operator.
class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NewOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    auto enc = getSparseTensorEncoding(resType);
    if (!enc)
      return failure();
    Value perm;
    rewriter.replaceOp(
        op, genNewCall(rewriter, op, enc, 0, perm, adaptor.getOperands()[0]));
    return success();
  }
};

/// Sparse conversion rule for the convert operator.
class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ConvertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    auto encDst = getSparseTensorEncoding(resType);
    auto encSrc = getSparseTensorEncoding(op.source().getType());
    auto src = adaptor.getOperands()[0];
    if (encDst && encSrc) {
      // This is a sparse => sparse conversion, which is handled as follows:
      //   t = src->toCOO();         ; src to COO in dst order
      //   dst = newSparseTensor(t)
      // Using the coordinate scheme as an intermediate does not always
      // yield the fastest conversion but avoids the need for a full
      // O(N^2) conversion matrix.
      Value perm;
      Value coo = genNewCall(rewriter, op, encDst, 3, perm, src);
      rewriter.replaceOp(op, genNewCall(rewriter, op, encDst, 1, perm, coo));
      return success();
    }
    if (!encDst || encSrc) {
      // TODO: sparse => dense
      return failure();
    }
    // This is a dense => sparse conversion or a sparse constant in COO =>
    // sparse conversion, which is handled as follows:
    //   t = newSparseCOO()
    //   ...code to fill the COO tensor t...
    //   s = newSparseTensor(t)
    //
    // To fill the COO tensor from a dense tensor:
    //   for i1 in dim1
    //    ..
    //     for ik in dimk
    //       val = a[i1,..,ik]
    //       if val != 0
    //         t->add(val, [i1,..,ik], [p1,..,pk])
    //
    // To fill the COO tensor from a sparse constant in COO format:
    //   for i in range(NNZ)
    //     val = values[i]
    //     [i1,..,ik] = indices[i]
    //     t->add(val, [i1,..,ik], [p1,..,pk])
    //
    // Note that the dense tensor traversal code is actually implemented
    // using MLIR IR to avoid having to expose too many low-level
    // memref traversal details to the runtime support library.
    // Also note that the code below only generates the "new" ops and
    // the loop-nest per se; the entire body of the innermost loop is
    // generated by the genIndexAndValueFor* and genAddEltCall() helpers.
    Location loc = op->getLoc();
    ShapedType shape = resType.cast<ShapedType>();
    Value perm;
    Value ptr = genNewCall(rewriter, op, encDst, 2, perm);
    Value ind = allocaIndices(rewriter, loc, shape.getRank());
    SmallVector<Value> lo;
    SmallVector<Value> hi;
    SmallVector<Value> st;
    Value zero = constantIndex(rewriter, loc, 0);
    Value one = constantIndex(rewriter, loc, 1);
    auto indicesValues = genSplitSparseConstant(rewriter, loc, src);
    bool isCOOConstant = indicesValues.hasValue();
    Value indices;
    Value values;
    if (isCOOConstant) {
      indices = indicesValues->first;
      values = indicesValues->second;
      lo.push_back(zero);
      hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, values, 0));
      st.push_back(one);
    } else {
      for (unsigned i = 0, rank = shape.getRank(); i < rank; i++) {
        lo.push_back(zero);
        hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, src, i));
        st.push_back(one);
      }
    }
    Type eltType = shape.getElementType();
    unsigned rank = shape.getRank();
    scf::buildLoopNest(
        rewriter, op.getLoc(), lo, hi, st, {},
        [&](OpBuilder &builder, Location loc, ValueRange ivs,
            ValueRange args) -> scf::ValueVector {
          Value val;
          if (isCOOConstant)
            val = genIndexAndValueForSparse(rewriter, loc, indices, values, ind,
                                            ivs, rank);
          else
            val = genIndexAndValueForDense(rewriter, loc, src, ind, ivs);
          genAddEltCall(rewriter, op, eltType, ptr, val, ind, perm);
          return {};
        });
    rewriter.replaceOp(op, genNewCall(rewriter, op, encDst, 1, perm, ptr));
    return success();
  }
};

/// Sparse conversion rule for the release operator.
class SparseTensorReleaseConverter : public OpConversionPattern<ReleaseOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReleaseOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    StringRef name = "delSparseTensor";
    TypeRange none;
    rewriter.create<CallOp>(op.getLoc(), none,
                            getFunc(op, name, none, adaptor.getOperands()),
                            adaptor.getOperands());
    rewriter.eraseOp(op);
    return success();
  }
};

/// Sparse conversion rule for pointer accesses.
class SparseTensorToPointersConverter
    : public OpConversionPattern<ToPointersOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToPointersOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type eltType = resType.cast<ShapedType>().getElementType();
    StringRef name;
    if (eltType.isIndex())
      name = "sparsePointers"; // 64-bit, but its own name for unique signature
    else if (eltType.isInteger(64))
      name = "sparsePointers64";
    else if (eltType.isInteger(32))
      name = "sparsePointers32";
    else if (eltType.isInteger(16))
      name = "sparsePointers16";
    else if (eltType.isInteger(8))
      name = "sparsePointers8";
    else
      return failure();
    rewriter.replaceOpWithNewOp<CallOp>(op, resType,
                                        getFunc(op, name, resType,
                                                adaptor.getOperands(),
                                                /*emitCInterface=*/true),
                                        adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for index accesses.
class SparseTensorToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToIndicesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type eltType = resType.cast<ShapedType>().getElementType();
    StringRef name;
    if (eltType.isIndex())
      name = "sparseIndices"; // 64-bit, but its own name for unique signature
    else if (eltType.isInteger(64))
      name = "sparseIndices64";
    else if (eltType.isInteger(32))
      name = "sparseIndices32";
    else if (eltType.isInteger(16))
      name = "sparseIndices16";
    else if (eltType.isInteger(8))
      name = "sparseIndices8";
    else
      return failure();
    rewriter.replaceOpWithNewOp<CallOp>(op, resType,
                                        getFunc(op, name, resType,
                                                adaptor.getOperands(),
                                                /*emitCInterface=*/true),
                                        adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for value accesses.
class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type eltType = resType.cast<ShapedType>().getElementType();
    StringRef name;
    if (eltType.isF64())
      name = "sparseValuesF64";
    else if (eltType.isF32())
      name = "sparseValuesF32";
    else if (eltType.isInteger(64))
      name = "sparseValuesI64";
    else if (eltType.isInteger(32))
      name = "sparseValuesI32";
    else if (eltType.isInteger(16))
      name = "sparseValuesI16";
    else if (eltType.isInteger(8))
      name = "sparseValuesI8";
    else
      return failure();
    rewriter.replaceOpWithNewOp<CallOp>(op, resType,
                                        getFunc(op, name, resType,
                                                adaptor.getOperands(),
                                                /*emitCInterface=*/true),
                                        adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for tensor reconstruction.
class SparseTensorToTensorConverter : public OpConversionPattern<ToTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  // Simply fold the operator into the pointer to the sparse storage scheme.
  LogicalResult
  matchAndRewrite(ToTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Check that all arguments of the tensor reconstruction operator are
    // calls into the support library that query exactly the same opaque
    // pointer.
    Value ptr;
    for (Value op : adaptor.getOperands()) {
      if (auto call = op.getDefiningOp<CallOp>()) {
        Value arg = call.getOperand(0);
        if (!arg.getType().isa<LLVM::LLVMPointerType>())
          return failure();
        if (!ptr)
          ptr = arg;
        else if (arg != ptr)
          return failure();
      }
    }
    // If a single opaque pointer is found, perform the folding.
    if (!ptr)
      return failure();
    rewriter.replaceOp(op, ptr);
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with conversion rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
                                                  RewritePatternSet &patterns) {
  patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter,
               SparseTensorNewConverter, SparseTensorConvertConverter,
               SparseTensorReleaseConverter, SparseTensorToPointersConverter,
               SparseTensorToIndicesConverter, SparseTensorToValuesConverter,
               SparseTensorToTensorConverter>(typeConverter,
                                              patterns.getContext());
}