//===- SparseTensorConversion.cpp - Sparse tensor primitives conversion ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Convert sparse tensor primitives to calls into a runtime support library.
// Note that this is a deliberate implementation choice to keep the conversion
// simple. In principle, these primitives could also be lowered to elaborate
// IR that implements them directly on the selected sparse tensor storage
// schemes.
//
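// For example, with the rules below, a primitive such as
//
//   %v = sparse_tensor.values %t : tensor<?xf64, #SV> to memref<?xf64>
//
// lowers to an opaque-pointer call along the lines of
//
//   %v = call @sparseValuesF64(%t) : (!llvm.ptr<i8>) -> memref<?xf64>
//
// where #SV denotes an arbitrary sparse encoding on the tensor type.
//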
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

/// Actions performed by the 'newSparseTensor' runtime method. Keep these
/// values consistent with the sparse runtime support library.
enum Action : uint32_t {
  kEmpty = 0,
  kFromFile = 1,
  kFromCOO = 2,
  kEmptyCOO = 3,
  kToCOO = 4
};

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

/// Returns internal type encoding for primary storage. Keep these
/// values consistent with the sparse runtime support library.
static unsigned getPrimaryTypeEncoding(Type tp) {
  if (tp.isF64())
    return 1;
  if (tp.isF32())
    return 2;
  if (tp.isInteger(64))
    return 3;
  if (tp.isInteger(32))
    return 4;
  if (tp.isInteger(16))
    return 5;
  if (tp.isInteger(8))
    return 6;
  return 0;
}

/// Returns internal type encoding for overhead storage. Keep these
/// values consistent with the sparse runtime support library.
static unsigned getOverheadTypeEncoding(unsigned width) {
  switch (width) {
  default:
    return 1;
  case 32:
    return 2;
  case 16:
    return 3;
  case 8:
    return 4;
  }
}

/// Returns internal dimension level type encoding. Keep these
/// values consistent with the sparse runtime support library.
static unsigned
getDimLevelTypeEncoding(SparseTensorEncodingAttr::DimLevelType dlt) {
  switch (dlt) {
  case SparseTensorEncodingAttr::DimLevelType::Dense:
    return 0;
  case SparseTensorEncodingAttr::DimLevelType::Compressed:
    return 1;
  case SparseTensorEncodingAttr::DimLevelType::Singleton:
    return 2;
  }
  llvm_unreachable("Unknown SparseTensorEncodingAttr::DimLevelType");
}

/// Generates a constant zero of the given type.
inline static Value constantZero(ConversionPatternRewriter &rewriter,
                                 Location loc, Type t) {
  return rewriter.create<arith::ConstantOp>(loc, t, rewriter.getZeroAttr(t));
}

/// Generates a constant of `index` type.
inline static Value constantIndex(ConversionPatternRewriter &rewriter,
                                  Location loc, unsigned i) {
  return rewriter.create<arith::ConstantIndexOp>(loc, i);
}

/// Generates a constant of `i64` type.
inline static Value constantI64(ConversionPatternRewriter &rewriter,
                                Location loc, int64_t i) {
  return rewriter.create<arith::ConstantIntOp>(loc, i, 64);
}

/// Generates a constant of `i32` type.
inline static Value constantI32(ConversionPatternRewriter &rewriter,
                                Location loc, int32_t i) {
  return rewriter.create<arith::ConstantIntOp>(loc, i, 32);
}

/// Generates a constant of `i8` type.
inline static Value constantI8(ConversionPatternRewriter &rewriter,
                               Location loc, int8_t i) {
  return rewriter.create<arith::ConstantIntOp>(loc, i, 8);
}

/// Returns a function reference (first hit also inserts into module). Sets
/// the "llvm.emit_c_interface" attribute on the function declaration when
/// requested, so that LLVM lowering generates a wrapper function that takes
/// care of ABI complications with passing in and returning MemRefs to C
/// functions.
static FlatSymbolRefAttr getFunc(Operation *op, StringRef name,
                                 TypeRange resultType, ValueRange operands,
                                 bool emitCInterface = false) {
  MLIRContext *context = op->getContext();
  auto module = op->getParentOfType<ModuleOp>();
  auto result = SymbolRefAttr::get(context, name);
  auto func = module.lookupSymbol<FuncOp>(result.getAttr());
  if (!func) {
    OpBuilder moduleBuilder(module.getBodyRegion());
    func = moduleBuilder.create<FuncOp>(
        op->getLoc(), name,
        FunctionType::get(context, operands.getTypes(), resultType));
    func.setPrivate();
    if (emitCInterface)
      func->setAttr("llvm.emit_c_interface", UnitAttr::get(context));
  }
  return result;
}

/// Generates a temporary buffer of the given size and type.
static Value genAlloca(ConversionPatternRewriter &rewriter, Location loc,
                       unsigned sz, Type tp) {
  auto memTp = MemRefType::get({ShapedType::kDynamicSize}, tp);
  Value a = constantIndex(rewriter, loc, sz);
  return rewriter.create<memref::AllocaOp>(loc, memTp, ValueRange{a});
}

/// Fills a temporary buffer of the given type with arguments.
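/// As a sketch, for three index values this generates roughly:
///   %a = memref.alloca(%c3) : memref<?xindex>
///   memref.store %v0, %a[%c0] : memref<?xindex>
///   memref.store %v1, %a[%c1] : memref<?xindex>
///   memref.store %v2, %a[%c2] : memref<?xindex>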
static Value genBuffer(ConversionPatternRewriter &rewriter, Location loc,
                       ArrayRef<Value> values) {
  unsigned sz = values.size();
  assert(sz >= 1);
  Value buffer = genAlloca(rewriter, loc, sz, values[0].getType());
  for (unsigned i = 0; i < sz; i++) {
    Value idx = constantIndex(rewriter, loc, i);
    rewriter.create<memref::StoreOp>(loc, values[i], buffer, idx);
  }
  return buffer;
}

/// Generates a call into the "Swiss army knife" method of the sparse runtime
/// support library for materializing sparse tensors into the computation. The
/// method returns the call value and assigns the permutation to 'perm'.
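/// As a sketch (operand names abbreviated), the generated call has the form:
///   %t = call @newSparseTensor(%attrs, %sizes, %perm,
///                              %secPtr, %secInd, %primary, %action, %ptr)
///       : (...) -> !llvm.ptr<i8>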
static Value genNewCall(ConversionPatternRewriter &rewriter, Operation *op,
                        SparseTensorEncodingAttr &enc, uint32_t action,
                        Value &perm, ValueRange szs, Value ptr = Value()) {
  Location loc = op->getLoc();
  ShapedType resType = op->getResult(0).getType().cast<ShapedType>();
  SmallVector<Value, 8> params;
  // Sparsity annotations in tensor constant form.
  SmallVector<Value, 4> attrs;
  ArrayRef<SparseTensorEncodingAttr::DimLevelType> dlt = enc.getDimLevelType();
  unsigned sz = dlt.size();
  for (unsigned i = 0; i < sz; i++)
    attrs.push_back(constantI8(rewriter, loc, getDimLevelTypeEncoding(dlt[i])));
  params.push_back(genBuffer(rewriter, loc, attrs));
  // Dimension sizes array of the enveloping *dense* tensor. Useful either
  // for verification of external data or for construction of internal data.
  auto shape = resType.getShape();
  SmallVector<Value, 4> sizes;
  if (!szs.empty()) {
    for (Value s : szs)
      sizes.push_back(
          rewriter.create<arith::IndexCastOp>(loc, s, rewriter.getI64Type()));
  } else {
    for (unsigned i = 0; i < sz; i++) {
      uint64_t s = shape[i] == ShapedType::kDynamicSize ? 0 : shape[i];
      sizes.push_back(constantI64(rewriter, loc, s));
    }
  }
  params.push_back(genBuffer(rewriter, loc, sizes));
  // Dimension order permutation array. This is the "identity" permutation by
  // default, or otherwise the "inverse" permutation of a given ordering, so
  // that indices can be mapped quickly to the right position.
  SmallVector<Value, 4> rev(sz);
  if (AffineMap p = enc.getDimOrdering()) {
    for (unsigned i = 0; i < sz; i++)
      rev[p.getDimPosition(i)] = constantI64(rewriter, loc, i);
  } else {
    for (unsigned i = 0; i < sz; i++)
      rev[i] = constantI64(rewriter, loc, i);
  }
  perm = genBuffer(rewriter, loc, rev);
  params.push_back(perm);
  // Secondary and primary types encoding.
  unsigned secPtr = getOverheadTypeEncoding(enc.getPointerBitWidth());
  unsigned secInd = getOverheadTypeEncoding(enc.getIndexBitWidth());
  unsigned primary = getPrimaryTypeEncoding(resType.getElementType());
  assert(primary);
  params.push_back(constantI64(rewriter, loc, secPtr));
  params.push_back(constantI64(rewriter, loc, secInd));
  params.push_back(constantI64(rewriter, loc, primary));
  // User action and pointer.
  Type pTp = LLVM::LLVMPointerType::get(rewriter.getI8Type());
  if (!ptr)
    ptr = rewriter.create<LLVM::NullOp>(loc, pTp);
  params.push_back(constantI32(rewriter, loc, action));
  params.push_back(ptr);
  // Generate the call to create new tensor.
  StringRef name = "newSparseTensor";
  auto call = rewriter.create<CallOp>(
      loc, pTp, getFunc(op, name, pTp, params, /*emitCInterface=*/true),
      params);
  return call.getResult(0);
}

/// Generates the comparison `v != 0` where `v` is of numeric type `t`.
/// For floating types, we use the "unordered" comparator (i.e., returns
/// true if `v` is NaN).
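/// For example, for an `f64` operand this generates roughly:
///   %0 = arith.cmpf une, %v, %zero : f64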
static Value genIsNonzero(ConversionPatternRewriter &rewriter, Location loc,
                          Value v) {
  Type t = v.getType();
  Value zero = constantZero(rewriter, loc, t);
  if (t.isa<FloatType>())
    return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE, v,
                                          zero);
  if (t.isIntOrIndex())
    return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, v,
                                          zero);
  llvm_unreachable("Unknown element type");
}

/// Generates the code to read the value from tensor[ivs], and conditionally
/// stores the indices ivs to the memory in ind. The generated code looks like
/// the following, and the insertion point after this routine is left inside
/// the if-then branch, right after the assignment to ind, to ensure that the
/// subsequently generated addEltX call lands inside that branch as well.
///    if (tensor[ivs] != 0) {
///      ind = ivs
static Value genIndexAndValueForDense(ConversionPatternRewriter &rewriter,
                                      Location loc, Value tensor, Value ind,
                                      ValueRange ivs) {
  Value val = rewriter.create<tensor::ExtractOp>(loc, tensor, ivs);
  Value cond = genIsNonzero(rewriter, loc, val);
  scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ false);
  rewriter.setInsertionPointToStart(&ifOp.thenRegion().front());
  unsigned i = 0;
  for (auto iv : ivs) {
    Value idx = constantIndex(rewriter, loc, i++);
    rewriter.create<memref::StoreOp>(loc, iv, ind, idx);
  }
  return val;
}

/// Generates a call that adds one element to a coordinate scheme.
/// In particular, this generates code like the following:
///   val = a[i1,..,ik];
///   if val != 0
///     t->add(val, [i1,..,ik], [p1,..,pk]);
static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
                          Type eltType, Value ptr, Value val, Value ind,
                          Value perm) {
  Location loc = op->getLoc();
  StringRef name;
  if (eltType.isF64())
    name = "addEltF64";
  else if (eltType.isF32())
    name = "addEltF32";
  else if (eltType.isInteger(64))
    name = "addEltI64";
  else if (eltType.isInteger(32))
    name = "addEltI32";
  else if (eltType.isInteger(16))
    name = "addEltI16";
  else if (eltType.isInteger(8))
    name = "addEltI8";
  else
    llvm_unreachable("Unknown element type");
  SmallVector<Value, 8> params;
  params.push_back(ptr);
  params.push_back(val);
  params.push_back(ind);
  params.push_back(perm);
  Type pTp = LLVM::LLVMPointerType::get(rewriter.getI8Type());
  rewriter.create<CallOp>(
      loc, pTp, getFunc(op, name, pTp, params, /*emitCInterface=*/true),
      params);
}

/// If the tensor is a sparse constant, generates and returns a pair of
/// constants for its indices and values.
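/// For example, a constant such as
///   arith.constant sparse<[[0, 0], [1, 2]], [1.0, 5.0]> : tensor<8x8xf64>
/// is split into a 2x2 index constant and a 2-element value constant.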
static Optional<std::pair<Value, Value>>
genSplitSparseConstant(ConversionPatternRewriter &rewriter, Location loc,
                       Value tensor) {
  if (auto constOp = tensor.getDefiningOp<arith::ConstantOp>()) {
    if (auto attr = constOp.value().dyn_cast<SparseElementsAttr>()) {
      DenseElementsAttr indicesAttr = attr.getIndices();
      Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr);
      DenseElementsAttr valuesAttr = attr.getValues();
      Value values = rewriter.create<arith::ConstantOp>(loc, valuesAttr);
      return std::make_pair(indices, values);
    }
  }
  return {};
}

/// Generates the code to copy the index at indices[ivs] to ind, and returns
/// the value at values[ivs].
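/// As a sketch, per dimension `i` this generates roughly:
///   %0 = tensor.extract %indices[%iv, %i] : tensor<?x?xi64>
///   %1 = arith.index_cast %0 : i64 to index
///   memref.store %1, %ind[%i] : memref<?xindex>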
static Value genIndexAndValueForSparse(ConversionPatternRewriter &rewriter,
                                       Location loc, Value indices,
                                       Value values, Value ind, ValueRange ivs,
                                       unsigned rank) {
  for (unsigned i = 0; i < rank; i++) {
    Value idx = constantIndex(rewriter, loc, i);
    Value val = rewriter.create<tensor::ExtractOp>(loc, indices,
                                                   ValueRange{ivs[0], idx});
    val =
        rewriter.create<arith::IndexCastOp>(loc, val, rewriter.getIndexType());
    rewriter.create<memref::StoreOp>(loc, val, ind, idx);
  }
  return rewriter.create<tensor::ExtractOp>(loc, values, ivs[0]);
}

//===----------------------------------------------------------------------===//
// Conversion rules.
//===----------------------------------------------------------------------===//

/// Sparse conversion rule for returns.
class SparseReturnConverter : public OpConversionPattern<ReturnOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<ReturnOp>(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for dimension accesses.
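/// For example, using a hypothetical #CSC encoding with dimension ordering
/// affine_map<(i, j) -> (j, i)>, a query such as
///   %d = tensor.dim %t, %c0 : tensor<?x?xf64, #CSC>
/// lowers to a call along the lines of
///   %d = call @sparseDimSize(%t, %c1) : (!llvm.ptr<i8>, index) -> index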
class SparseTensorToDimSizeConverter
    : public OpConversionPattern<tensor::DimOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    auto enc = getSparseTensorEncoding(op.source().getType());
    if (!enc)
      return failure();
    // Permute the dim index.
    Optional<int64_t> index = op.getConstantIndex();
    if (!index.hasValue())
      return failure();
    int64_t idx = index.getValue();
    if (AffineMap p = enc.getDimOrdering())
      idx = p.getPermutedPosition(idx);
    // Generate the call.
    StringRef name = "sparseDimSize";
    SmallVector<Value, 2> params;
    params.push_back(adaptor.getOperands()[0]);
    params.push_back(constantIndex(rewriter, op.getLoc(), idx));
    rewriter.replaceOpWithNewOp<CallOp>(
        op, resType, getFunc(op, name, resType, params), params);
    return success();
  }
};

/// Sparse conversion rule for the new operator.
class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NewOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    auto enc = getSparseTensorEncoding(resType);
    if (!enc)
      return failure();
    Value perm;
    rewriter.replaceOp(op, genNewCall(rewriter, op, enc, kFromFile, perm, {},
                                      adaptor.getOperands()[0]));
    return success();
  }
};

/// Sparse conversion rule for the init operator.
class SparseTensorInitConverter : public OpConversionPattern<InitOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(InitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    auto enc = getSparseTensorEncoding(resType);
    if (!enc)
      return failure();
    Value perm;
    rewriter.replaceOp(
        op, genNewCall(rewriter, op, enc, kEmpty, perm, adaptor.getOperands()));
    return success();
  }
};

/// Sparse conversion rule for the convert operator.
class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ConvertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    auto encDst = getSparseTensorEncoding(resType);
    auto encSrc = getSparseTensorEncoding(op.source().getType());
    auto src = adaptor.getOperands()[0];
    if (encDst && encSrc) {
      // This is a sparse => sparse conversion, which is handled as follows:
      //   t = src->toCOO();          // src to COO in dst order
      //   dst = newSparseTensor(t)
      // Using the coordinate scheme as an intermediate does not always
      // yield the fastest conversion but avoids the need for a full
      // O(N^2) conversion matrix.
      Value perm;
      Value coo = genNewCall(rewriter, op, encDst, kToCOO, perm, {}, src);
      rewriter.replaceOp(
          op, genNewCall(rewriter, op, encDst, kFromCOO, perm, {}, coo));
      return success();
    }
    if (!encDst || encSrc) {
      // TODO: sparse => dense
      return failure();
    }
    // This is a dense => sparse conversion or a sparse constant in COO =>
    // sparse conversion, which is handled as follows:
    //   t = newSparseCOO()
    //   ...code to fill the COO tensor t...
    //   s = newSparseTensor(t)
    //
    // To fill the COO tensor from a dense tensor:
    //   for i1 in dim1
    //    ..
    //     for ik in dimk
    //       val = a[i1,..,ik]
    //       if val != 0
    //         t->add(val, [i1,..,ik], [p1,..,pk])
    //
    // To fill the COO tensor from a sparse constant in COO format:
    //   for i in range(NNZ)
    //     val = values[i]
    //     [i1,..,ik] = indices[i]
    //     t->add(val, [i1,..,ik], [p1,..,pk])
    //
    // Note that the dense tensor traversal code is actually implemented
    // using MLIR IR to avoid having to expose too many low-level
    // memref traversal details to the runtime support library.
    // Also note that the code below only generates the "new" ops and
    // the loop-nest per se; the entire body of the innermost loop
    // is generated by genAddEltCall().
    Location loc = op->getLoc();
    ShapedType shape = resType.cast<ShapedType>();
    Value perm;
    Value ptr = genNewCall(rewriter, op, encDst, kEmptyCOO, perm, {});
    Value ind =
        genAlloca(rewriter, loc, shape.getRank(), rewriter.getIndexType());
    SmallVector<Value> lo;
    SmallVector<Value> hi;
    SmallVector<Value> st;
    Value zero = constantIndex(rewriter, loc, 0);
    Value one = constantIndex(rewriter, loc, 1);
    auto indicesValues = genSplitSparseConstant(rewriter, loc, src);
    bool isCOOConstant = indicesValues.hasValue();
    Value indices;
    Value values;
    if (isCOOConstant) {
      indices = indicesValues->first;
      values = indicesValues->second;
      lo.push_back(zero);
      hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, values, 0));
      st.push_back(one);
    } else {
      for (unsigned i = 0, rank = shape.getRank(); i < rank; i++) {
        lo.push_back(zero);
        hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, src, i));
        st.push_back(one);
      }
    }
    Type eltType = shape.getElementType();
    unsigned rank = shape.getRank();
    scf::buildLoopNest(
        rewriter, loc, lo, hi, st, {},
        [&](OpBuilder &builder, Location loc, ValueRange ivs,
            ValueRange args) -> scf::ValueVector {
          Value val;
          if (isCOOConstant)
            val = genIndexAndValueForSparse(rewriter, loc, indices, values, ind,
                                            ivs, rank);
          else
            val = genIndexAndValueForDense(rewriter, loc, src, ind, ivs);
          genAddEltCall(rewriter, op, eltType, ptr, val, ind, perm);
          return {};
        });
    rewriter.replaceOp(
        op, genNewCall(rewriter, op, encDst, kFromCOO, perm, {}, ptr));
    return success();
  }
};

/// Sparse conversion rule for the release operator.
class SparseTensorReleaseConverter : public OpConversionPattern<ReleaseOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReleaseOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    StringRef name = "delSparseTensor";
    TypeRange none;
    rewriter.create<CallOp>(op.getLoc(), none,
                            getFunc(op, name, none, adaptor.getOperands()),
                            adaptor.getOperands());
    rewriter.eraseOp(op);
    return success();
  }
};

/// Sparse conversion rule for pointer accesses.
class SparseTensorToPointersConverter
    : public OpConversionPattern<ToPointersOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToPointersOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type eltType = resType.cast<ShapedType>().getElementType();
    StringRef name;
    if (eltType.isIndex())
      name = "sparsePointers"; // 64-bit, but its own name for unique signature
    else if (eltType.isInteger(64))
      name = "sparsePointers64";
    else if (eltType.isInteger(32))
      name = "sparsePointers32";
    else if (eltType.isInteger(16))
      name = "sparsePointers16";
    else if (eltType.isInteger(8))
      name = "sparsePointers8";
    else
      return failure();
    rewriter.replaceOpWithNewOp<CallOp>(op, resType,
                                        getFunc(op, name, resType,
                                                adaptor.getOperands(),
                                                /*emitCInterface=*/true),
                                        adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for index accesses.
class SparseTensorToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToIndicesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type eltType = resType.cast<ShapedType>().getElementType();
    StringRef name;
    if (eltType.isIndex())
      name = "sparseIndices"; // 64-bit, but its own name for unique signature
    else if (eltType.isInteger(64))
      name = "sparseIndices64";
    else if (eltType.isInteger(32))
      name = "sparseIndices32";
    else if (eltType.isInteger(16))
      name = "sparseIndices16";
    else if (eltType.isInteger(8))
      name = "sparseIndices8";
    else
      return failure();
    rewriter.replaceOpWithNewOp<CallOp>(op, resType,
                                        getFunc(op, name, resType,
                                                adaptor.getOperands(),
                                                /*emitCInterface=*/true),
                                        adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for value accesses.
class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type eltType = resType.cast<ShapedType>().getElementType();
    StringRef name;
    if (eltType.isF64())
      name = "sparseValuesF64";
    else if (eltType.isF32())
      name = "sparseValuesF32";
    else if (eltType.isInteger(64))
      name = "sparseValuesI64";
    else if (eltType.isInteger(32))
      name = "sparseValuesI32";
    else if (eltType.isInteger(16))
      name = "sparseValuesI16";
    else if (eltType.isInteger(8))
      name = "sparseValuesI8";
    else
      return failure();
    rewriter.replaceOpWithNewOp<CallOp>(op, resType,
                                        getFunc(op, name, resType,
                                                adaptor.getOperands(),
                                                /*emitCInterface=*/true),
                                        adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for tensor reconstruction.
class SparseTensorToTensorConverter : public OpConversionPattern<ToTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  // Simply fold the operator into the pointer to the sparse storage scheme.
  LogicalResult
  matchAndRewrite(ToTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Check that all arguments of the tensor reconstruction operator are
    // calls into the support library that query exactly the same opaque
    // pointer.
    Value ptr;
    for (Value op : adaptor.getOperands()) {
      if (auto call = op.getDefiningOp<CallOp>()) {
        Value arg = call.getOperand(0);
        if (!arg.getType().isa<LLVM::LLVMPointerType>())
          return failure();
        if (!ptr)
          ptr = arg;
        else if (arg != ptr)
          return failure();
      }
    }
    // If a single opaque pointer is found, perform the folding.
    if (!ptr)
      return failure();
    rewriter.replaceOp(op, ptr);
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with the conversion rules required to
/// convert sparse tensor primitives into calls to the runtime support library.
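///
/// A conversion pass would wire up these rules roughly as follows (sketch
/// only; `converter` denotes a suitable TypeConverter and `target` the
/// configured ConversionTarget):
///
///   RewritePatternSet patterns(op->getContext());
///   populateSparseTensorConversionPatterns(converter, patterns);
///   if (failed(applyPartialConversion(op, target, std::move(patterns))))
///     signalPassFailure();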
void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
                                                  RewritePatternSet &patterns) {
  patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter,
               SparseTensorNewConverter, SparseTensorInitConverter,
               SparseTensorConvertConverter, SparseTensorReleaseConverter,
               SparseTensorToPointersConverter, SparseTensorToIndicesConverter,
               SparseTensorToValuesConverter, SparseTensorToTensorConverter>(
      typeConverter, patterns.getContext());
}