//===- SparseTensorConversion.cpp - Sparse tensor primitives conversion ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Convert sparse tensor primitives to calls into a runtime support library.
// Note that this is a current implementation choice to keep the conversion
// simple. In principle, these primitives could also be converted to actual
// elaborate IR code that implements the primitives on the selected sparse
// tensor storage schemes.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

/// New tensor storage action. Keep these values consistent with
/// the sparse runtime support library.
enum Action : uint32_t {
  kEmpty = 0,
  kFromFile = 1,
  kFromCOO = 2,
  kEmptyCOO = 3,
  kToCOO = 4
};

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

/// Returns internal type encoding for primary storage. Keep these
/// values consistent with the sparse runtime support library.
static uint32_t getPrimaryTypeEncoding(Type tp) {
  if (tp.isF64())
    return 1;
  if (tp.isF32())
    return 2;
  if (tp.isInteger(64))
    return 3;
  if (tp.isInteger(32))
    return 4;
  if (tp.isInteger(16))
    return 5;
  if (tp.isInteger(8))
    return 6;
  return 0;
}

/// Returns internal type encoding for overhead storage. Keep these
/// values consistent with the sparse runtime support library.
static uint32_t getOverheadTypeEncoding(unsigned width) {
  switch (width) {
  default:
    return 1;
  case 32:
    return 2;
  case 16:
    return 3;
  case 8:
    return 4;
  }
}

/// Returns internal dimension level type encoding. Keep these
/// values consistent with the sparse runtime support library.
static uint32_t
getDimLevelTypeEncoding(SparseTensorEncodingAttr::DimLevelType dlt) {
  switch (dlt) {
  case SparseTensorEncodingAttr::DimLevelType::Dense:
    return 0;
  case SparseTensorEncodingAttr::DimLevelType::Compressed:
    return 1;
  case SparseTensorEncodingAttr::DimLevelType::Singleton:
    return 2;
  }
  llvm_unreachable("Unknown SparseTensorEncodingAttr::DimLevelType");
}

/// Generates a constant zero of the given type.
inline static Value constantZero(ConversionPatternRewriter &rewriter,
                                 Location loc, Type t) {
  return rewriter.create<arith::ConstantOp>(loc, t, rewriter.getZeroAttr(t));
}

/// Generates a constant of `index` type.
inline static Value constantIndex(ConversionPatternRewriter &rewriter,
                                  Location loc, int64_t i) {
  return rewriter.create<arith::ConstantIndexOp>(loc, i);
}

/// Generates a constant of `i32` type.
inline static Value constantI32(ConversionPatternRewriter &rewriter,
                                Location loc, int32_t i) {
  return rewriter.create<arith::ConstantIntOp>(loc, i, 32);
}

/// Generates a constant of `i8` type.
inline static Value constantI8(ConversionPatternRewriter &rewriter,
                               Location loc, int8_t i) {
  return rewriter.create<arith::ConstantIntOp>(loc, i, 8);
}

/// Returns a function reference (first hit also inserts into module). Sets
/// the "llvm.emit_c_interface" attribute on the function declaration when
/// requested, so that LLVM lowering generates a wrapper function that takes
/// care of ABI complications with passing in and returning MemRefs to C
/// functions.
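/// For example, the first lookup of "sparseDimSize" inserts roughly the
/// following declaration at module scope (illustration only):
///   func private @sparseDimSize(!llvm.ptr<i8>, index) -> index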
static FlatSymbolRefAttr getFunc(Operation *op, StringRef name,
                                 TypeRange resultType, ValueRange operands,
                                 bool emitCInterface = false) {
  MLIRContext *context = op->getContext();
  auto module = op->getParentOfType<ModuleOp>();
  auto result = SymbolRefAttr::get(context, name);
  auto func = module.lookupSymbol<FuncOp>(result.getAttr());
  if (!func) {
    OpBuilder moduleBuilder(module.getBodyRegion());
    func = moduleBuilder.create<FuncOp>(
        op->getLoc(), name,
        FunctionType::get(context, operands.getTypes(), resultType));
    func.setPrivate();
    if (emitCInterface)
      func->setAttr("llvm.emit_c_interface", UnitAttr::get(context));
  }
  return result;
}

/// Generates dimension size call.
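/// For idx = 1, for example, this emits roughly (SSA names are illustrative):
///   %c1 = arith.constant 1 : index
///   %d = call @sparseDimSize(%src, %c1) : (!llvm.ptr<i8>, index) -> index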
static Value genDimSizeCall(ConversionPatternRewriter &rewriter, Operation *op,
                            SparseTensorEncodingAttr &enc, Value src,
                            int64_t idx) {
  // Permute the index according to an optional dimension ordering.
  if (AffineMap p = enc.getDimOrdering())
    idx = p.getPermutedPosition(idx);
  // Generate the call.
  Location loc = op->getLoc();
  StringRef name = "sparseDimSize";
  SmallVector<Value, 2> params;
  params.push_back(src);
  params.push_back(constantIndex(rewriter, loc, idx));
  Type iTp = rewriter.getIndexType();
  auto fn = getFunc(op, name, iTp, params);
  return rewriter.create<CallOp>(loc, iTp, fn, params).getResult(0);
}

/// Generates a call into the "swiss army knife" method of the sparse runtime
/// support library for materializing sparse tensors into the computation.
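/// The emitted call has the form below (operand names are illustrative; see
/// newParams() further down for how the eight parameters are assembled):
///   %t = call @newSparseTensor(%attrs, %sizes, %perm, %secPtr, %secInd,
///                              %primary, %action, %ptr) -> !llvm.ptr<i8>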
static Value genNewCall(ConversionPatternRewriter &rewriter, Operation *op,
                        ArrayRef<Value> params) {
  Location loc = op->getLoc();
  StringRef name = "newSparseTensor";
  Type pTp = LLVM::LLVMPointerType::get(rewriter.getI8Type());
  auto fn = getFunc(op, name, pTp, params, /*emitCInterface=*/true);
  auto call = rewriter.create<CallOp>(loc, pTp, fn, params);
  return call.getResult(0);
}

/// Populates given sizes array from type.
static void sizesFromType(ConversionPatternRewriter &rewriter,
                          SmallVector<Value, 4> &sizes, Location loc,
                          ShapedType stp) {
  auto shape = stp.getShape();
  for (unsigned i = 0, rank = stp.getRank(); i < rank; i++) {
    uint64_t s = shape[i] == ShapedType::kDynamicSize ? 0 : shape[i];
    sizes.push_back(constantIndex(rewriter, loc, s));
  }
}

/// Populates given sizes array from source.
static void sizesFromSrc(ConversionPatternRewriter &rewriter,
                         SmallVector<Value, 4> &sizes, Location loc,
                         Value src) {
  ShapedType stp = src.getType().cast<ShapedType>();
  for (unsigned i = 0, rank = stp.getRank(); i < rank; i++)
    sizes.push_back(linalg::createOrFoldDimOp(rewriter, loc, src, i));
}

/// Populates given sizes array from type (for static sizes) and from
/// the source, already converted into an opaque pointer (for dynamic sizes).
static void sizesFromPtr(ConversionPatternRewriter &rewriter,
                         SmallVector<Value, 4> &sizes, Operation *op,
                         SparseTensorEncodingAttr &enc, ShapedType stp,
                         Value src) {
  auto shape = stp.getShape();
  for (unsigned i = 0, rank = stp.getRank(); i < rank; i++)
    if (shape[i] == ShapedType::kDynamicSize)
      sizes.push_back(genDimSizeCall(rewriter, op, enc, src, i));
    else
      sizes.push_back(constantIndex(rewriter, op->getLoc(), shape[i]));
}

/// Generates a temporary buffer of the given size and type.
static Value genAlloca(ConversionPatternRewriter &rewriter, Location loc,
                       unsigned sz, Type tp) {
  auto memTp = MemRefType::get({ShapedType::kDynamicSize}, tp);
  Value a = constantIndex(rewriter, loc, sz);
  return rewriter.create<memref::AllocaOp>(loc, memTp, ValueRange{a});
}

/// Generates a temporary buffer of the given type and given contents.
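/// For two index values, for example, this emits roughly:
///   %a = memref.alloca(%c2) : memref<?xindex>
///   memref.store %v0, %a[%c0] : memref<?xindex>
///   memref.store %v1, %a[%c1] : memref<?xindex>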
static Value genBuffer(ConversionPatternRewriter &rewriter, Location loc,
                       ArrayRef<Value> values) {
  unsigned sz = values.size();
  assert(sz >= 1);
  Value buffer = genAlloca(rewriter, loc, sz, values[0].getType());
  for (unsigned i = 0; i < sz; i++) {
    Value idx = constantIndex(rewriter, loc, i);
    rewriter.create<memref::StoreOp>(loc, values[i], buffer, idx);
  }
  return buffer;
}

/// Populates parameters required to call the "swiss army knife" method of the
/// sparse runtime support library for materializing sparse tensors into the
/// computation.
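/// The params vector is filled, in order, with: attrs, sizes, perm, secPtr,
/// secInd, primary, action, ptr. Callers below rely on this layout when
/// overwriting params[6] (the action) and params[7] (the pointer) before
/// re-issuing a call.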
static void newParams(ConversionPatternRewriter &rewriter,
                      SmallVector<Value, 8> &params, Operation *op,
                      SparseTensorEncodingAttr &enc, uint32_t action,
                      ValueRange szs, Value ptr = Value()) {
  Location loc = op->getLoc();
  ArrayRef<SparseTensorEncodingAttr::DimLevelType> dlt = enc.getDimLevelType();
  unsigned sz = dlt.size();
  // Sparsity annotations.
  SmallVector<Value, 4> attrs;
  for (unsigned i = 0; i < sz; i++)
    attrs.push_back(constantI8(rewriter, loc, getDimLevelTypeEncoding(dlt[i])));
  params.push_back(genBuffer(rewriter, loc, attrs));
  // Dimension sizes array of the enveloping tensor. Useful for either
  // verification of external data, or for construction of internal data.
  SmallVector<Value, 4> sizes;
  for (Value s : szs)
    sizes.push_back(s);
  params.push_back(genBuffer(rewriter, loc, sizes));
  // Dimension order permutation array. This is the "identity" permutation by
  // default, or otherwise the "reverse" permutation of a given ordering, so
  // that indices can be mapped quickly to the right position.
  SmallVector<Value, 4> rev(sz);
  if (AffineMap p = enc.getDimOrdering()) {
    for (unsigned i = 0; i < sz; i++)
      rev[p.getDimPosition(i)] = constantIndex(rewriter, loc, i);
  } else {
    for (unsigned i = 0; i < sz; i++)
      rev[i] = constantIndex(rewriter, loc, i);
  }
  params.push_back(genBuffer(rewriter, loc, rev));
  // Secondary and primary types encoding.
  ShapedType resType = op->getResult(0).getType().cast<ShapedType>();
  uint32_t secPtr = getOverheadTypeEncoding(enc.getPointerBitWidth());
  uint32_t secInd = getOverheadTypeEncoding(enc.getIndexBitWidth());
  uint32_t primary = getPrimaryTypeEncoding(resType.getElementType());
  assert(primary);
  params.push_back(constantI32(rewriter, loc, secPtr));
  params.push_back(constantI32(rewriter, loc, secInd));
  params.push_back(constantI32(rewriter, loc, primary));
  // User action and pointer.
  Type pTp = LLVM::LLVMPointerType::get(rewriter.getI8Type());
  if (!ptr)
    ptr = rewriter.create<LLVM::NullOp>(loc, pTp);
  params.push_back(constantI32(rewriter, loc, action));
  params.push_back(ptr);
}

/// Generates the comparison `v != 0` where `v` is of numeric type `t`.
/// For floating types, we use the "unordered" comparator (i.e., returns
/// true if `v` is NaN).
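/// For f64 values, for example, this emits roughly:
///   %zero = arith.constant 0.0 : f64
///   %nz = arith.cmpf une, %v, %zero : f64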
static Value genIsNonzero(ConversionPatternRewriter &rewriter, Location loc,
                          Value v) {
  Type t = v.getType();
  Value zero = constantZero(rewriter, loc, t);
  if (t.isa<FloatType>())
    return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE, v,
                                          zero);
  if (t.isIntOrIndex())
    return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, v,
                                          zero);
  llvm_unreachable("Unknown element type");
}

/// Generates the code to read the value from tensor[ivs], and conditionally
/// stores the indices ivs to the memory in ind. The insertion point after
/// this routine is inside the if-then branch that follows the assignment to
/// ind, which ensures that the addEltX call generated afterwards is also
/// placed inside that branch. The generated code looks like the following:
///    if (tensor[ivs] != 0) {
///      ind = ivs
static Value genIndexAndValueForDense(ConversionPatternRewriter &rewriter,
                                      Location loc, Value tensor, Value ind,
                                      ValueRange ivs) {
  Value val = rewriter.create<tensor::ExtractOp>(loc, tensor, ivs);
  Value cond = genIsNonzero(rewriter, loc, val);
  scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ false);
  rewriter.setInsertionPointToStart(&ifOp.thenRegion().front());
  unsigned i = 0;
  for (auto iv : ivs) {
    Value idx = constantIndex(rewriter, loc, i++);
    rewriter.create<memref::StoreOp>(loc, iv, ind, idx);
  }
  return val;
}

/// Generates a call that adds one element to a coordinate scheme.
/// In particular, this generates code like the following:
///   val = a[i1,..,ik];
///   if val != 0
///     t->add(val, [i1,..,ik], [p1,..,pk]);
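/// For f64 elements, the emitted call is roughly (names illustrative):
///   %p = call @addEltF64(%ptr, %val, %ind, %perm) -> !llvm.ptr<i8>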
static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
                          Type eltType, Value ptr, Value val, Value ind,
                          Value perm) {
  Location loc = op->getLoc();
  StringRef name;
  if (eltType.isF64())
    name = "addEltF64";
  else if (eltType.isF32())
    name = "addEltF32";
  else if (eltType.isInteger(64))
    name = "addEltI64";
  else if (eltType.isInteger(32))
    name = "addEltI32";
  else if (eltType.isInteger(16))
    name = "addEltI16";
  else if (eltType.isInteger(8))
    name = "addEltI8";
  else
    llvm_unreachable("Unknown element type");
  SmallVector<Value, 8> params;
  params.push_back(ptr);
  params.push_back(val);
  params.push_back(ind);
  params.push_back(perm);
  Type pTp = LLVM::LLVMPointerType::get(rewriter.getI8Type());
  auto fn = getFunc(op, name, pTp, params, /*emitCInterface=*/true);
  rewriter.create<CallOp>(loc, pTp, fn, params);
}

/// If the tensor is a sparse constant, generates and returns the pair of
/// the constants for the indices and the values.
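/// For example, the sparse constant
///   sparse<[[0, 0], [1, 2]], [1.0, 5.0]> : tensor<8x8xf64>
/// is split into a dense 2x2 index constant and a 2-element value constant.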
static Optional<std::pair<Value, Value>>
genSplitSparseConstant(ConversionPatternRewriter &rewriter, Location loc,
                       Value tensor) {
  if (auto constOp = tensor.getDefiningOp<arith::ConstantOp>()) {
    if (auto attr = constOp.getValue().dyn_cast<SparseElementsAttr>()) {
      DenseElementsAttr indicesAttr = attr.getIndices();
      Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr);
      DenseElementsAttr valuesAttr = attr.getValues();
      Value values = rewriter.create<arith::ConstantOp>(loc, valuesAttr);
      return std::make_pair(indices, values);
    }
  }
  return {};
}

/// Generates the code to copy the index at indices[ivs] to ind, and returns
/// the value at values[ivs].
static Value genIndexAndValueForSparse(ConversionPatternRewriter &rewriter,
                                       Location loc, Value indices,
                                       Value values, Value ind, ValueRange ivs,
                                       unsigned rank) {
  for (unsigned i = 0; i < rank; i++) {
    Value idx = constantIndex(rewriter, loc, i);
    Value val = rewriter.create<tensor::ExtractOp>(loc, indices,
                                                   ValueRange{ivs[0], idx});
    val =
        rewriter.create<arith::IndexCastOp>(loc, val, rewriter.getIndexType());
    rewriter.create<memref::StoreOp>(loc, val, ind, idx);
  }
  return rewriter.create<tensor::ExtractOp>(loc, values, ivs[0]);
}

//===----------------------------------------------------------------------===//
// Conversion rules.
//===----------------------------------------------------------------------===//

/// Sparse conversion rule for returns.
class SparseReturnConverter : public OpConversionPattern<ReturnOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<ReturnOp>(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for dimension accesses.
class SparseTensorToDimSizeConverter
    : public OpConversionPattern<tensor::DimOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only rewrite annotated DimOp with constant index.
    auto enc = getSparseTensorEncoding(op.source().getType());
    if (!enc)
      return failure();
    Optional<int64_t> index = op.getConstantIndex();
    if (!index.hasValue())
      return failure();
    // Generate the call.
    Value src = adaptor.getOperands()[0];
    int64_t idx = index.getValue();
    rewriter.replaceOp(op, genDimSizeCall(rewriter, op, enc, src, idx));
    return success();
  }
};

/// Sparse conversion rule for the new operator.
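/// For example (the #CSR alias is illustrative and syntax abbreviated),
///   %t = sparse_tensor.new %src : !llvm.ptr<i8> to tensor<8x8xf64, #CSR>
/// becomes a @newSparseTensor call with action kFromFile, with the sizes
/// taken from the static result type.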
class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NewOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    auto enc = getSparseTensorEncoding(resType);
    if (!enc)
      return failure();
    // Generate the call to construct tensor from ptr. The sizes are
    // inferred from the result type of the new operator.
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 8> params;
    sizesFromType(rewriter, sizes, op.getLoc(), resType.cast<ShapedType>());
    Value ptr = adaptor.getOperands()[0];
    newParams(rewriter, params, op, enc, kFromFile, sizes, ptr);
    rewriter.replaceOp(op, genNewCall(rewriter, op, params));
    return success();
  }
};

/// Sparse conversion rule for the init operator.
class SparseTensorInitConverter : public OpConversionPattern<InitOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(InitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    auto enc = getSparseTensorEncoding(resType);
    if (!enc)
      return failure();
    // Generate the call to construct empty tensor. The sizes are
    // explicitly defined by the arguments to the init operator.
    SmallVector<Value, 8> params;
    newParams(rewriter, params, op, enc, kEmpty, adaptor.getOperands());
    rewriter.replaceOp(op, genNewCall(rewriter, op, params));
    return success();
  }
};

/// Sparse conversion rule for the convert operator.
class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ConvertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    Type resType = op.getType();
    Type srcType = op.source().getType();
    auto encDst = getSparseTensorEncoding(resType);
    auto encSrc = getSparseTensorEncoding(srcType);
    Value src = adaptor.getOperands()[0];
    if (encDst && encSrc) {
      // This is a sparse => sparse conversion, which is handled as follows:
      //   t = src->toCOO();         ; src to COO in dst order
      //   dst = newSparseTensor(t)
      // Using the coordinate scheme as an intermediate does not always
      // yield the fastest conversion but avoids the need for a full
      // O(N^2) conversion matrix.
      SmallVector<Value, 4> sizes;
      SmallVector<Value, 8> params;
      sizesFromPtr(rewriter, sizes, op, encSrc, srcType.cast<ShapedType>(),
                   src);
      newParams(rewriter, params, op, encDst, kToCOO, sizes, src);
      Value coo = genNewCall(rewriter, op, params);
      params[6] = constantI32(rewriter, loc, kFromCOO);
      params[7] = coo;
      rewriter.replaceOp(op, genNewCall(rewriter, op, params));
      return success();
    }
    if (!encDst || encSrc) {
      // TODO: sparse => dense
      return failure();
    }
    // This is a dense => sparse conversion or a sparse constant in COO =>
    // sparse conversion, which is handled as follows:
    //   t = newSparseCOO()
    //   ...code to fill the COO tensor t...
    //   s = newSparseTensor(t)
    //
    // To fill the COO tensor from a dense tensor:
    //   for i1 in dim1
    //    ..
    //     for ik in dimk
    //       val = a[i1,..,ik]
    //       if val != 0
    //         t->add(val, [i1,..,ik], [p1,..,pk])
    //
    // To fill the COO tensor from a sparse constant in COO format:
    //   for i in range(NNZ)
    //     val = values[i]
    //     [i1,..,ik] = indices[i]
    //     t->add(val, [i1,..,ik], [p1,..,pk])
    //
    // Note that the dense tensor traversal code is actually implemented
    // using MLIR IR to avoid having to expose too much low-level
    // memref traversal detail to the runtime support library.
    // Also note that the code below only generates the "new" ops and
    // the loop-nest per se; the entire body of the innermost loop is
    // generated by genAddEltCall().
    ShapedType stp = resType.cast<ShapedType>();
    unsigned rank = stp.getRank();
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 8> params;
    sizesFromSrc(rewriter, sizes, loc, src);
    newParams(rewriter, params, op, encDst, kEmptyCOO, sizes);
    Value ptr = genNewCall(rewriter, op, params);
    Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
    Value perm = params[2];
    SmallVector<Value> lo;
    SmallVector<Value> hi;
    SmallVector<Value> st;
    Value zero = constantIndex(rewriter, loc, 0);
    Value one = constantIndex(rewriter, loc, 1);
    auto indicesValues = genSplitSparseConstant(rewriter, loc, src);
    bool isCOOConstant = indicesValues.hasValue();
    Value indices;
    Value values;
    if (isCOOConstant) {
      indices = indicesValues->first;
      values = indicesValues->second;
      lo.push_back(zero);
      hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, values, 0));
      st.push_back(one);
    } else {
      for (unsigned i = 0; i < rank; i++) {
        lo.push_back(zero);
        hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, src, i));
        st.push_back(one);
      }
    }
    Type eltType = stp.getElementType();
    scf::buildLoopNest(
        rewriter, op.getLoc(), lo, hi, st, {},
        [&](OpBuilder &builder, Location loc, ValueRange ivs,
            ValueRange args) -> scf::ValueVector {
          Value val;
          if (isCOOConstant)
            val = genIndexAndValueForSparse(rewriter, loc, indices, values, ind,
                                            ivs, rank);
          else
            val = genIndexAndValueForDense(rewriter, loc, src, ind, ivs);
          genAddEltCall(rewriter, op, eltType, ptr, val, ind, perm);
          return {};
        });
    // Final call to construct sparse tensor storage.
    params[6] = constantI32(rewriter, loc, kFromCOO);
    params[7] = ptr;
    rewriter.replaceOp(op, genNewCall(rewriter, op, params));
    return success();
  }
};

/// Sparse conversion rule for the release operator.
class SparseTensorReleaseConverter : public OpConversionPattern<ReleaseOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReleaseOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    StringRef name = "delSparseTensor";
    TypeRange none;
    auto fn = getFunc(op, name, none, adaptor.getOperands());
    rewriter.create<CallOp>(op.getLoc(), none, fn, adaptor.getOperands());
    rewriter.eraseOp(op);
    return success();
  }
};

/// Sparse conversion rule for pointer accesses.
class SparseTensorToPointersConverter
    : public OpConversionPattern<ToPointersOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToPointersOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type eltType = resType.cast<ShapedType>().getElementType();
    StringRef name;
    if (eltType.isIndex())
      name = "sparsePointers";
    else if (eltType.isInteger(64))
      name = "sparsePointers64";
    else if (eltType.isInteger(32))
      name = "sparsePointers32";
    else if (eltType.isInteger(16))
      name = "sparsePointers16";
    else if (eltType.isInteger(8))
      name = "sparsePointers8";
    else
      return failure();
    auto fn = getFunc(op, name, resType, adaptor.getOperands(),
                      /*emitCInterface=*/true);
    rewriter.replaceOpWithNewOp<CallOp>(op, resType, fn, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for index accesses.
class SparseTensorToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToIndicesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type eltType = resType.cast<ShapedType>().getElementType();
    StringRef name;
    if (eltType.isIndex())
      name = "sparseIndices";
    else if (eltType.isInteger(64))
      name = "sparseIndices64";
    else if (eltType.isInteger(32))
      name = "sparseIndices32";
    else if (eltType.isInteger(16))
      name = "sparseIndices16";
    else if (eltType.isInteger(8))
      name = "sparseIndices8";
    else
      return failure();
    auto fn = getFunc(op, name, resType, adaptor.getOperands(),
                      /*emitCInterface=*/true);
    rewriter.replaceOpWithNewOp<CallOp>(op, resType, fn, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for value accesses.
class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type eltType = resType.cast<ShapedType>().getElementType();
    StringRef name;
    if (eltType.isF64())
      name = "sparseValuesF64";
    else if (eltType.isF32())
      name = "sparseValuesF32";
    else if (eltType.isInteger(64))
      name = "sparseValuesI64";
    else if (eltType.isInteger(32))
      name = "sparseValuesI32";
    else if (eltType.isInteger(16))
      name = "sparseValuesI16";
    else if (eltType.isInteger(8))
      name = "sparseValuesI8";
    else
      return failure();
    auto fn = getFunc(op, name, resType, adaptor.getOperands(),
                      /*emitCInterface=*/true);
    rewriter.replaceOpWithNewOp<CallOp>(op, resType, fn, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for tensor reconstruction.
class SparseTensorToTensorConverter : public OpConversionPattern<ToTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  // Simply fold the operator into the pointer to the sparse storage scheme.
  LogicalResult
  matchAndRewrite(ToTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Check that all arguments of the tensor reconstruction operators are calls
    // into the support library that query exactly the same opaque pointer.
    Value ptr;
    for (Value op : adaptor.getOperands()) {
      if (auto call = op.getDefiningOp<CallOp>()) {
        Value arg = call.getOperand(0);
        if (!arg.getType().isa<LLVM::LLVMPointerType>())
          return failure();
        if (!ptr)
          ptr = arg;
        else if (arg != ptr)
          return failure();
      }
    }
    // If a single opaque pointer is found, perform the folding.
    if (!ptr)
      return failure();
    rewriter.replaceOp(op, ptr);
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with conversion rules required for
/// the sparsification of linear algebra operations.
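/// A minimal usage sketch from a conversion pass (the type converter and
/// conversion target set-up are assumptions of this illustration; any
/// converter that maps annotated tensor types to !llvm.ptr<i8> will do):
///   RewritePatternSet patterns(ctx);
///   SparseTensorTypeConverter converter; // hypothetical converter name
///   populateSparseTensorConversionPatterns(converter, patterns);
///   if (failed(applyPartialConversion(func, target, std::move(patterns))))
///     signalPassFailure();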
void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
                                                  RewritePatternSet &patterns) {
  patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter,
               SparseTensorNewConverter, SparseTensorInitConverter,
               SparseTensorConvertConverter, SparseTensorReleaseConverter,
               SparseTensorToPointersConverter, SparseTensorToIndicesConverter,
               SparseTensorToValuesConverter, SparseTensorToTensorConverter>(
      typeConverter, patterns.getContext());
}