1 //===- SparseTensorConversion.cpp - Sparse tensor primitives conversion ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Convert sparse tensor primitives to calls into a runtime support library.
10 // Note that this is a current implementation choice to keep the conversion
11 // simple. In principle, these primitives could also be converted to actual
12 // elaborate IR code that implements the primitives on the selected sparse
13 // tensor storage schemes.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
18 #include "mlir/Dialect/Linalg/Utils/Utils.h"
19 #include "mlir/Dialect/MemRef/IR/MemRef.h"
20 #include "mlir/Dialect/SCF/SCF.h"
21 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
22 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
23 #include "mlir/Dialect/StandardOps/IR/Ops.h"
24 #include "mlir/Dialect/Tensor/IR/Tensor.h"
25 #include "mlir/Transforms/DialectConversion.h"
26 
27 using namespace mlir;
28 using namespace mlir::sparse_tensor;
29 
30 namespace {
31 
/// New tensor storage action. Keep these values consistent with
/// the sparse runtime support library (they are passed verbatim as the
/// i32 "action" parameter of the newSparseTensor runtime call).
enum Action : uint32_t {
  kEmpty = 0,    // create empty storage
  kFromFile = 1, // read tensor contents from a file
  kFromCOO = 2,  // construct storage from an intermediate COO scheme
  kEmptyCOO = 3, // create an empty COO scheme to be filled
  kToCOO = 4     // convert existing storage into a COO scheme
};
41 
42 //===----------------------------------------------------------------------===//
43 // Helper methods.
44 //===----------------------------------------------------------------------===//
45 
46 /// Returns internal type encoding for primary storage. Keep these
47 /// values consistent with the sparse runtime support library.
48 static uint32_t getPrimaryTypeEncoding(Type tp) {
49   if (tp.isF64())
50     return 1;
51   if (tp.isF32())
52     return 2;
53   if (tp.isInteger(64))
54     return 3;
55   if (tp.isInteger(32))
56     return 4;
57   if (tp.isInteger(16))
58     return 5;
59   if (tp.isInteger(8))
60     return 6;
61   return 0;
62 }
63 
/// Returns internal type encoding for overhead storage. Keep these
/// values consistent with the sparse runtime support library.
/// Any width other than 32/16/8 (including the default 64) encodes as 1.
static uint32_t getOverheadTypeEncoding(unsigned width) {
  if (width == 32)
    return 2;
  if (width == 16)
    return 3;
  if (width == 8)
    return 4;
  return 1;
}
78 
79 /// Returns internal dimension level type encoding. Keep these
80 /// values consistent with the sparse runtime support library.
81 static uint32_t
82 getDimLevelTypeEncoding(SparseTensorEncodingAttr::DimLevelType dlt) {
83   switch (dlt) {
84   case SparseTensorEncodingAttr::DimLevelType::Dense:
85     return 0;
86   case SparseTensorEncodingAttr::DimLevelType::Compressed:
87     return 1;
88   case SparseTensorEncodingAttr::DimLevelType::Singleton:
89     return 2;
90   }
91   llvm_unreachable("Unknown SparseTensorEncodingAttr::DimLevelType");
92 }
93 
94 /// Generates a constant zero of the given type.
95 inline static Value constantZero(ConversionPatternRewriter &rewriter,
96                                  Location loc, Type t) {
97   return rewriter.create<arith::ConstantOp>(loc, t, rewriter.getZeroAttr(t));
98 }
99 
100 /// Generates a constant of `index` type.
101 inline static Value constantIndex(ConversionPatternRewriter &rewriter,
102                                   Location loc, int64_t i) {
103   return rewriter.create<arith::ConstantIndexOp>(loc, i);
104 }
105 
106 /// Generates a constant of `i32` type.
107 inline static Value constantI32(ConversionPatternRewriter &rewriter,
108                                 Location loc, int32_t i) {
109   return rewriter.create<arith::ConstantIntOp>(loc, i, 32);
110 }
111 
112 /// Generates a constant of `i8` type.
113 inline static Value constantI8(ConversionPatternRewriter &rewriter,
114                                Location loc, int8_t i) {
115   return rewriter.create<arith::ConstantIntOp>(loc, i, 8);
116 }
117 
118 /// Returns a function reference (first hit also inserts into module). Sets
119 /// the "_emit_c_interface" on the function declaration when requested,
120 /// so that LLVM lowering generates a wrapper function that takes care
121 /// of ABI complications with passing in and returning MemRefs to C functions.
122 static FlatSymbolRefAttr getFunc(Operation *op, StringRef name,
123                                  TypeRange resultType, ValueRange operands,
124                                  bool emitCInterface = false) {
125   MLIRContext *context = op->getContext();
126   auto module = op->getParentOfType<ModuleOp>();
127   auto result = SymbolRefAttr::get(context, name);
128   auto func = module.lookupSymbol<FuncOp>(result.getAttr());
129   if (!func) {
130     OpBuilder moduleBuilder(module.getBodyRegion());
131     func = moduleBuilder.create<FuncOp>(
132         op->getLoc(), name,
133         FunctionType::get(context, operands.getTypes(), resultType));
134     func.setPrivate();
135     if (emitCInterface)
136       func->setAttr("llvm.emit_c_interface", UnitAttr::get(context));
137   }
138   return result;
139 }
140 
141 /// Generates dimension size call.
142 static Value genDimSizeCall(ConversionPatternRewriter &rewriter, Operation *op,
143                             SparseTensorEncodingAttr &enc, Value src,
144                             int64_t idx) {
145   // Permute the index according to an optional dimension ordering.
146   if (AffineMap p = enc.getDimOrdering())
147     idx = p.getPermutedPosition(idx);
148   // Generate the call.
149   Location loc = op->getLoc();
150   StringRef name = "sparseDimSize";
151   SmallVector<Value, 2> params;
152   params.push_back(src);
153   params.push_back(constantIndex(rewriter, loc, idx));
154   Type iTp = rewriter.getIndexType();
155   auto fn = getFunc(op, name, iTp, params);
156   return rewriter.create<CallOp>(loc, iTp, fn, params).getResult(0);
157 }
158 
159 /// Generates a call into the "swiss army knife" method of the sparse runtime
160 /// support library for materializing sparse tensors into the computation.
161 static Value genNewCall(ConversionPatternRewriter &rewriter, Operation *op,
162                         ArrayRef<Value> params) {
163   Location loc = op->getLoc();
164   StringRef name = "newSparseTensor";
165   Type pTp = LLVM::LLVMPointerType::get(rewriter.getI8Type());
166   auto fn = getFunc(op, name, pTp, params, /*emitCInterface=*/true);
167   auto call = rewriter.create<CallOp>(loc, pTp, fn, params);
168   return call.getResult(0);
169 }
170 
171 /// Populates given sizes array from type.
172 static void sizesFromType(ConversionPatternRewriter &rewriter,
173                           SmallVector<Value, 4> &sizes, Location loc,
174                           ShapedType stp) {
175   auto shape = stp.getShape();
176   for (unsigned i = 0, rank = stp.getRank(); i < rank; i++) {
177     uint64_t s = shape[i] == ShapedType::kDynamicSize ? 0 : shape[i];
178     sizes.push_back(constantIndex(rewriter, loc, s));
179   }
180 }
181 
182 /// Populates given sizes array from source.
183 static void sizesFromSrc(ConversionPatternRewriter &rewriter,
184                          SmallVector<Value, 4> &sizes, Location loc,
185                          Value src) {
186   ShapedType stp = src.getType().cast<ShapedType>();
187   for (unsigned i = 0, rank = stp.getRank(); i < rank; i++)
188     sizes.push_back(linalg::createOrFoldDimOp(rewriter, loc, src, i));
189 }
190 
191 /// Populates given sizes array from type (for static sizes) and from
192 /// an already converted into opague pointer source (for dynamic sizes).
193 static void sizesFromPtr(ConversionPatternRewriter &rewriter,
194                          SmallVector<Value, 4> &sizes, Operation *op,
195                          SparseTensorEncodingAttr &enc, ShapedType stp,
196                          Value src) {
197   auto shape = stp.getShape();
198   for (unsigned i = 0, rank = stp.getRank(); i < rank; i++)
199     if (shape[i] == ShapedType::kDynamicSize)
200       sizes.push_back(genDimSizeCall(rewriter, op, enc, src, i));
201     else
202       sizes.push_back(constantIndex(rewriter, op->getLoc(), shape[i]));
203 }
204 
205 /// Generates a temporary buffer of the given size and type.
206 static Value genAlloca(ConversionPatternRewriter &rewriter, Location loc,
207                        unsigned sz, Type tp) {
208   auto memTp = MemRefType::get({ShapedType::kDynamicSize}, tp);
209   Value a = constantIndex(rewriter, loc, sz);
210   return rewriter.create<memref::AllocaOp>(loc, memTp, ValueRange{a});
211 }
212 
213 /// Generates a temporary buffer of the given type and given contents.
214 static Value genBuffer(ConversionPatternRewriter &rewriter, Location loc,
215                        ArrayRef<Value> values) {
216   unsigned sz = values.size();
217   assert(sz >= 1);
218   Value buffer = genAlloca(rewriter, loc, sz, values[0].getType());
219   for (unsigned i = 0; i < sz; i++) {
220     Value idx = constantIndex(rewriter, loc, i);
221     rewriter.create<memref::StoreOp>(loc, values[i], buffer, idx);
222   }
223   return buffer;
224 }
225 
/// Populates parameters required to call the "swiss army knife" method of the
/// sparse runtime support library for materializing sparse tensors into the
/// computation.
///
/// The runtime expects the parameters in exactly this order (and callers such
/// as the convert lowering rely on the last two positions when they overwrite
/// params[6]/params[7] for a follow-up call):
///   params[0] per-dimension level-type annotations (i8 buffer)
///   params[1] dimension sizes (index buffer)
///   params[2] reverse dimension permutation (index buffer)
///   params[3] pointer overhead type encoding (i32)
///   params[4] index overhead type encoding (i32)
///   params[5] primary element type encoding (i32)
///   params[6] action code (i32)
///   params[7] opaque pointer argument for the action (null when absent)
static void newParams(ConversionPatternRewriter &rewriter,
                      SmallVector<Value, 8> &params, Operation *op,
                      SparseTensorEncodingAttr &enc, uint32_t action,
                      ValueRange szs, Value ptr = Value()) {
  Location loc = op->getLoc();
  ArrayRef<SparseTensorEncodingAttr::DimLevelType> dlt = enc.getDimLevelType();
  unsigned sz = dlt.size();
  // Sparsity annotations.
  SmallVector<Value, 4> attrs;
  for (unsigned i = 0; i < sz; i++)
    attrs.push_back(constantI8(rewriter, loc, getDimLevelTypeEncoding(dlt[i])));
  params.push_back(genBuffer(rewriter, loc, attrs));
  // Dimension sizes array of the enveloping tensor. Useful for either
  // verification of external data, or for construction of internal data.
  SmallVector<Value, 4> sizes;
  for (Value s : szs)
    sizes.push_back(s);
  params.push_back(genBuffer(rewriter, loc, sizes));
  // Dimension order permutation array. This is the "identity" permutation by
  // default, or otherwise the "reverse" permutation of a given ordering, so
  // that indices can be mapped quickly to the right position.
  SmallVector<Value, 4> rev(sz);
  if (AffineMap p = enc.getDimOrdering()) {
    for (unsigned i = 0; i < sz; i++)
      rev[p.getDimPosition(i)] = constantIndex(rewriter, loc, i);
  } else {
    for (unsigned i = 0; i < sz; i++)
      rev[i] = constantIndex(rewriter, loc, i);
  }
  params.push_back(genBuffer(rewriter, loc, rev));
  // Secondary and primary types encoding.
  ShapedType resType = op->getResult(0).getType().cast<ShapedType>();
  uint32_t secPtr = getOverheadTypeEncoding(enc.getPointerBitWidth());
  uint32_t secInd = getOverheadTypeEncoding(enc.getIndexBitWidth());
  uint32_t primary = getPrimaryTypeEncoding(resType.getElementType());
  assert(primary); // 0 would mean an unsupported element type
  params.push_back(constantI32(rewriter, loc, secPtr));
  params.push_back(constantI32(rewriter, loc, secInd));
  params.push_back(constantI32(rewriter, loc, primary));
  // User action and pointer (a null pointer when no argument is needed).
  Type pTp = LLVM::LLVMPointerType::get(rewriter.getI8Type());
  if (!ptr)
    ptr = rewriter.create<LLVM::NullOp>(loc, pTp);
  params.push_back(constantI32(rewriter, loc, action));
  params.push_back(ptr);
}
275 
276 /// Generates the comparison `v != 0` where `v` is of numeric type `t`.
277 /// For floating types, we use the "unordered" comparator (i.e., returns
278 /// true if `v` is NaN).
279 static Value genIsNonzero(ConversionPatternRewriter &rewriter, Location loc,
280                           Value v) {
281   Type t = v.getType();
282   Value zero = constantZero(rewriter, loc, t);
283   if (t.isa<FloatType>())
284     return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE, v,
285                                           zero);
286   if (t.isIntOrIndex())
287     return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, v,
288                                           zero);
289   llvm_unreachable("Unknown element type");
290 }
291 
292 /// Generates the code to read the value from tensor[ivs], and conditionally
293 /// stores the indices ivs to the memory in ind. The generated code looks like
294 /// the following and the insertion point after this routine is inside the
295 /// if-then branch behind the assignment to ind. This is to ensure that the
296 /// addEltX call generated after is inside the if-then branch.
297 ///    if (tensor[ivs]!=0) {
298 ///      ind = ivs
299 static Value genIndexAndValueForDense(ConversionPatternRewriter &rewriter,
300                                       Location loc, Value tensor, Value ind,
301                                       ValueRange ivs) {
302   Value val = rewriter.create<tensor::ExtractOp>(loc, tensor, ivs);
303   Value cond = genIsNonzero(rewriter, loc, val);
304   scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ false);
305   rewriter.setInsertionPointToStart(&ifOp.thenRegion().front());
306   unsigned i = 0;
307   for (auto iv : ivs) {
308     Value idx = constantIndex(rewriter, loc, i++);
309     rewriter.create<memref::StoreOp>(loc, iv, ind, idx);
310   }
311   return val;
312 }
313 
314 /// Generates a call that adds one element to a coordinate scheme.
315 /// In particular, this generates code like the following:
316 ///   val = a[i1,..,ik];
317 ///   if val != 0
318 ///     t->add(val, [i1,..,ik], [p1,..,pk]);
319 static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
320                           Type eltType, Value ptr, Value val, Value ind,
321                           Value perm) {
322   Location loc = op->getLoc();
323   StringRef name;
324   if (eltType.isF64())
325     name = "addEltF64";
326   else if (eltType.isF32())
327     name = "addEltF32";
328   else if (eltType.isInteger(64))
329     name = "addEltI64";
330   else if (eltType.isInteger(32))
331     name = "addEltI32";
332   else if (eltType.isInteger(16))
333     name = "addEltI16";
334   else if (eltType.isInteger(8))
335     name = "addEltI8";
336   else
337     llvm_unreachable("Unknown element type");
338   SmallVector<Value, 8> params;
339   params.push_back(ptr);
340   params.push_back(val);
341   params.push_back(ind);
342   params.push_back(perm);
343   Type pTp = LLVM::LLVMPointerType::get(rewriter.getI8Type());
344   auto fn = getFunc(op, name, pTp, params, /*emitCInterface=*/true);
345   rewriter.create<CallOp>(loc, pTp, fn, params);
346 }
347 
348 /// If the tensor is a sparse constant, generates and returns the pair of
349 /// the constants for the indices and the values.
350 static Optional<std::pair<Value, Value>>
351 genSplitSparseConstant(ConversionPatternRewriter &rewriter, Location loc,
352                        Value tensor) {
353   if (auto constOp = tensor.getDefiningOp<arith::ConstantOp>()) {
354     if (auto attr = constOp.getValue().dyn_cast<SparseElementsAttr>()) {
355       DenseElementsAttr indicesAttr = attr.getIndices();
356       Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr);
357       DenseElementsAttr valuesAttr = attr.getValues();
358       Value values = rewriter.create<arith::ConstantOp>(loc, valuesAttr);
359       return std::make_pair(indices, values);
360     }
361   }
362   return {};
363 }
364 
365 /// Generates the code to copy the index at indices[ivs] to ind, and return
366 /// the value at value[ivs].
367 static Value genIndexAndValueForSparse(ConversionPatternRewriter &rewriter,
368                                        Location loc, Value indices,
369                                        Value values, Value ind, ValueRange ivs,
370                                        unsigned rank) {
371   for (unsigned i = 0; i < rank; i++) {
372     Value idx = constantIndex(rewriter, loc, i);
373     Value val = rewriter.create<tensor::ExtractOp>(loc, indices,
374                                                    ValueRange{ivs[0], idx});
375     val =
376         rewriter.create<arith::IndexCastOp>(loc, val, rewriter.getIndexType());
377     rewriter.create<memref::StoreOp>(loc, val, ind, idx);
378   }
379   return rewriter.create<tensor::ExtractOp>(loc, values, ivs[0]);
380 }
381 
382 //===----------------------------------------------------------------------===//
383 // Conversion rules.
384 //===----------------------------------------------------------------------===//
385 
386 /// Sparse conversion rule for returns.
387 class SparseReturnConverter : public OpConversionPattern<ReturnOp> {
388 public:
389   using OpConversionPattern::OpConversionPattern;
390   LogicalResult
391   matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
392                   ConversionPatternRewriter &rewriter) const override {
393     rewriter.replaceOpWithNewOp<ReturnOp>(op, adaptor.getOperands());
394     return success();
395   }
396 };
397 
398 /// Sparse conversion rule for dimension accesses.
399 class SparseTensorToDimSizeConverter
400     : public OpConversionPattern<tensor::DimOp> {
401 public:
402   using OpConversionPattern::OpConversionPattern;
403   LogicalResult
404   matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
405                   ConversionPatternRewriter &rewriter) const override {
406     // Only rewrite annotated DimOp with constant index.
407     auto enc = getSparseTensorEncoding(op.source().getType());
408     if (!enc)
409       return failure();
410     Optional<int64_t> index = op.getConstantIndex();
411     if (!index.hasValue())
412       return failure();
413     // Generate the call.
414     Value src = adaptor.getOperands()[0];
415     int64_t idx = index.getValue();
416     rewriter.replaceOp(op, genDimSizeCall(rewriter, op, enc, src, idx));
417     return success();
418   }
419 };
420 
421 /// Sparse conversion rule for trivial tensor casts.
422 class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
423   using OpConversionPattern::OpConversionPattern;
424   LogicalResult
425   matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
426                   ConversionPatternRewriter &rewriter) const override {
427     // Only rewrite identically annotated source/dest.
428     auto encDst = getSparseTensorEncoding(op.getType());
429     auto encSrc = getSparseTensorEncoding(op.source().getType());
430     if (!encDst || encDst != encSrc)
431       return failure();
432     rewriter.replaceOp(op, adaptor.getOperands());
433     return success();
434   }
435 };
436 
437 /// Sparse conversion rule for the new operator.
438 class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
439   using OpConversionPattern::OpConversionPattern;
440   LogicalResult
441   matchAndRewrite(NewOp op, OpAdaptor adaptor,
442                   ConversionPatternRewriter &rewriter) const override {
443     Type resType = op.getType();
444     auto enc = getSparseTensorEncoding(resType);
445     if (!enc)
446       return failure();
447     // Generate the call to construct tensor from ptr. The sizes are
448     // inferred from the result type of the new operator.
449     SmallVector<Value, 4> sizes;
450     SmallVector<Value, 8> params;
451     sizesFromType(rewriter, sizes, op.getLoc(), resType.cast<ShapedType>());
452     Value ptr = adaptor.getOperands()[0];
453     newParams(rewriter, params, op, enc, kFromFile, sizes, ptr);
454     rewriter.replaceOp(op, genNewCall(rewriter, op, params));
455     return success();
456   }
457 };
458 
459 /// Sparse conversion rule for the init operator.
460 class SparseTensorInitConverter : public OpConversionPattern<InitOp> {
461   using OpConversionPattern::OpConversionPattern;
462   LogicalResult
463   matchAndRewrite(InitOp op, OpAdaptor adaptor,
464                   ConversionPatternRewriter &rewriter) const override {
465     Type resType = op.getType();
466     auto enc = getSparseTensorEncoding(resType);
467     if (!enc)
468       return failure();
469     // Generate the call to construct empty tensor. The sizes are
470     // explicitly defined by the arguments to the init operator.
471     SmallVector<Value, 8> params;
472     newParams(rewriter, params, op, enc, kEmpty, adaptor.getOperands());
473     rewriter.replaceOp(op, genNewCall(rewriter, op, params));
474     return success();
475   }
476 };
477 
/// Sparse conversion rule for the convert operator. Handles
/// sparse => sparse (via an intermediate COO scheme) and
/// dense/COO-constant => sparse; sparse => dense is not yet supported.
class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ConvertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    Type resType = op.getType();
    Type srcType = op.source().getType();
    auto encDst = getSparseTensorEncoding(resType);
    auto encSrc = getSparseTensorEncoding(srcType);
    Value src = adaptor.getOperands()[0];
    if (encDst && encSrc) {
      // This is a sparse => sparse conversion, which is handled as follows:
      //   t = src->toCOO();         ; src to COO in dst order
      //   dst = newSparseTensor(t)
      // Using the coordinate scheme as an intermediate does not always
      // yield the fastest conversion but avoids the need for a full
      // O(N^2) conversion matrix.
      if (encDst == encSrc) {
        rewriter.replaceOp(op, adaptor.getOperands()); // hidden nop cast
        return success();
      }
      SmallVector<Value, 4> sizes;
      SmallVector<Value, 8> params;
      sizesFromPtr(rewriter, sizes, op, encSrc, srcType.cast<ShapedType>(),
                   src);
      newParams(rewriter, params, op, encDst, kToCOO, sizes, src);
      Value coo = genNewCall(rewriter, op, params);
      // Reuse the parameter list: slots 6 and 7 are the action code and the
      // opaque pointer argument filled in by newParams().
      params[6] = constantI32(rewriter, loc, kFromCOO);
      params[7] = coo;
      rewriter.replaceOp(op, genNewCall(rewriter, op, params));
      return success();
    }
    if (!encDst || encSrc) {
      // TODO: sparse => dense
      return failure();
    }
    // This is a dense => sparse conversion or a sparse constant in COO =>
    // sparse conversion, which is handled as follows:
    //   t = newSparseCOO()
    //   ...code to fill the COO tensor t...
    //   s = newSparseTensor(t)
    //
    // To fill the COO tensor from a dense tensor:
    //   for i1 in dim1
    //    ..
    //     for ik in dimk
    //       val = a[i1,..,ik]
    //       if val != 0
    //         t->add(val, [i1,..,ik], [p1,..,pk])
    //
    // To fill the COO tensor from a sparse constant in COO format:
    //   for i in range(NNZ)
    //     val = values[i]
    //     [i1,..,ik] = indices[i]
    //     t->add(val, [i1,..,ik], [p1,..,pk])
    //
    // Note that the dense tensor traversal code is actually implemented
    // using MLIR IR to avoid having to expose too much low-level
    // memref traversal details to the runtime support library.
    // Also note that the code below only generates the "new" ops and
    // the loop-nest per se; whereas the entire body of the innermost
    // loop is generated by genAddElt().
    ShapedType stp = resType.cast<ShapedType>();
    unsigned rank = stp.getRank();
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 8> params;
    sizesFromSrc(rewriter, sizes, loc, src);
    newParams(rewriter, params, op, encDst, kEmptyCOO, sizes);
    Value ptr = genNewCall(rewriter, op, params);
    Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
    Value perm = params[2]; // reverse permutation buffer built by newParams()
    SmallVector<Value> lo;
    SmallVector<Value> hi;
    SmallVector<Value> st;
    Value zero = constantIndex(rewriter, loc, 0);
    Value one = constantIndex(rewriter, loc, 1);
    auto indicesValues = genSplitSparseConstant(rewriter, loc, src);
    bool isCOOConstant = indicesValues.hasValue();
    Value indices;
    Value values;
    if (isCOOConstant) {
      // Sparse constant: a single loop over its nonzeros.
      indices = indicesValues->first;
      values = indicesValues->second;
      lo.push_back(zero);
      hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, values, 0));
      st.push_back(one);
    } else {
      // Dense tensor: a full rank-deep loop nest.
      for (unsigned i = 0; i < rank; i++) {
        lo.push_back(zero);
        hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, src, i));
        st.push_back(one);
      }
    }
    Type eltType = stp.getElementType();
    scf::buildLoopNest(
        rewriter, op.getLoc(), lo, hi, st, {},
        [&](OpBuilder &builder, Location loc, ValueRange ivs,
            ValueRange args) -> scf::ValueVector {
          // NOTE(review): the body keeps using the captured `rewriter`
          // rather than the `builder` passed to the callback — confirm
          // this is intended.
          Value val;
          if (isCOOConstant)
            val = genIndexAndValueForSparse(rewriter, loc, indices, values, ind,
                                            ivs, rank);
          else
            val = genIndexAndValueForDense(rewriter, loc, src, ind, ivs);
          genAddEltCall(rewriter, op, eltType, ptr, val, ind, perm);
          return {};
        });
    // Final call to construct sparse tensor storage.
    params[6] = constantI32(rewriter, loc, kFromCOO);
    params[7] = ptr;
    rewriter.replaceOp(op, genNewCall(rewriter, op, params));
    return success();
  }
};
594 
595 /// Sparse conversion rule for the release operator.
596 class SparseTensorReleaseConverter : public OpConversionPattern<ReleaseOp> {
597 public:
598   using OpConversionPattern::OpConversionPattern;
599   LogicalResult
600   matchAndRewrite(ReleaseOp op, OpAdaptor adaptor,
601                   ConversionPatternRewriter &rewriter) const override {
602     StringRef name = "delSparseTensor";
603     TypeRange none;
604     auto fn = getFunc(op, name, none, adaptor.getOperands());
605     rewriter.create<CallOp>(op.getLoc(), none, fn, adaptor.getOperands());
606     rewriter.eraseOp(op);
607     return success();
608   }
609 };
610 
611 /// Sparse conversion rule for pointer accesses.
612 class SparseTensorToPointersConverter
613     : public OpConversionPattern<ToPointersOp> {
614 public:
615   using OpConversionPattern::OpConversionPattern;
616   LogicalResult
617   matchAndRewrite(ToPointersOp op, OpAdaptor adaptor,
618                   ConversionPatternRewriter &rewriter) const override {
619     Type resType = op.getType();
620     Type eltType = resType.cast<ShapedType>().getElementType();
621     StringRef name;
622     if (eltType.isIndex())
623       name = "sparsePointers";
624     else if (eltType.isInteger(64))
625       name = "sparsePointers64";
626     else if (eltType.isInteger(32))
627       name = "sparsePointers32";
628     else if (eltType.isInteger(16))
629       name = "sparsePointers16";
630     else if (eltType.isInteger(8))
631       name = "sparsePointers8";
632     else
633       return failure();
634     auto fn = getFunc(op, name, resType, adaptor.getOperands(),
635                       /*emitCInterface=*/true);
636     rewriter.replaceOpWithNewOp<CallOp>(op, resType, fn, adaptor.getOperands());
637     return success();
638   }
639 };
640 
641 /// Sparse conversion rule for index accesses.
642 class SparseTensorToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
643 public:
644   using OpConversionPattern::OpConversionPattern;
645   LogicalResult
646   matchAndRewrite(ToIndicesOp op, OpAdaptor adaptor,
647                   ConversionPatternRewriter &rewriter) const override {
648     Type resType = op.getType();
649     Type eltType = resType.cast<ShapedType>().getElementType();
650     StringRef name;
651     if (eltType.isIndex())
652       name = "sparseIndices";
653     else if (eltType.isInteger(64))
654       name = "sparseIndices64";
655     else if (eltType.isInteger(32))
656       name = "sparseIndices32";
657     else if (eltType.isInteger(16))
658       name = "sparseIndices16";
659     else if (eltType.isInteger(8))
660       name = "sparseIndices8";
661     else
662       return failure();
663     auto fn = getFunc(op, name, resType, adaptor.getOperands(),
664                       /*emitCInterface=*/true);
665     rewriter.replaceOpWithNewOp<CallOp>(op, resType, fn, adaptor.getOperands());
666     return success();
667   }
668 };
669 
670 /// Sparse conversion rule for value accesses.
671 class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
672 public:
673   using OpConversionPattern::OpConversionPattern;
674   LogicalResult
675   matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
676                   ConversionPatternRewriter &rewriter) const override {
677     Type resType = op.getType();
678     Type eltType = resType.cast<ShapedType>().getElementType();
679     StringRef name;
680     if (eltType.isF64())
681       name = "sparseValuesF64";
682     else if (eltType.isF32())
683       name = "sparseValuesF32";
684     else if (eltType.isInteger(64))
685       name = "sparseValuesI64";
686     else if (eltType.isInteger(32))
687       name = "sparseValuesI32";
688     else if (eltType.isInteger(16))
689       name = "sparseValuesI16";
690     else if (eltType.isInteger(8))
691       name = "sparseValuesI8";
692     else
693       return failure();
694     auto fn = getFunc(op, name, resType, adaptor.getOperands(),
695                       /*emitCInterface=*/true);
696     rewriter.replaceOpWithNewOp<CallOp>(op, resType, fn, adaptor.getOperands());
697     return success();
698   }
699 };
700 
701 /// Sparse conversion rule for tensor reconstruction.
702 class SparseTensorToTensorConverter : public OpConversionPattern<ToTensorOp> {
703 public:
704   using OpConversionPattern::OpConversionPattern;
705   LogicalResult
706   // Simply fold the operator into the pointer to the sparse storage scheme.
707   matchAndRewrite(ToTensorOp op, OpAdaptor adaptor,
708                   ConversionPatternRewriter &rewriter) const override {
709     // Check that all arguments of the tensor reconstruction operators are calls
710     // into the support library that query exactly the same opaque pointer.
711     Value ptr;
712     for (Value op : adaptor.getOperands()) {
713       if (auto call = op.getDefiningOp<CallOp>()) {
714         Value arg = call.getOperand(0);
715         if (!arg.getType().isa<LLVM::LLVMPointerType>())
716           return failure();
717         if (!ptr)
718           ptr = arg;
719         else if (arg != ptr)
720           return failure();
721       }
722     }
723     // If a single opaque pointer is found, perform the folding.
724     if (!ptr)
725       return failure();
726     rewriter.replaceOp(op, ptr);
727     return success();
728   }
729 };
730 
731 } // namespace
732 
733 //===----------------------------------------------------------------------===//
734 // Public method for populating conversion rules.
735 //===----------------------------------------------------------------------===//
736 
737 /// Populates the given patterns list with conversion rules required for
738 /// the sparsification of linear algebra operations.
739 void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
740                                                   RewritePatternSet &patterns) {
741   patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter,
742                SparseCastConverter, SparseTensorNewConverter,
743                SparseTensorInitConverter, SparseTensorConvertConverter,
744                SparseTensorReleaseConverter, SparseTensorToPointersConverter,
745                SparseTensorToIndicesConverter, SparseTensorToValuesConverter,
746                SparseTensorToTensorConverter>(typeConverter,
747                                               patterns.getContext());
748 }
749