//===- SparseTensorConversion.cpp - Sparse tensor primitives conversion ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Convert sparse tensor primitives to calls into a runtime support library.
// Note that this is a current implementation choice to keep the conversion
// simple. In principle, these primitives could also be converted to actual
// elaborate IR code that implements the primitives on the selected sparse
// tensor storage schemes.
//
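// For example, under this scheme an access such as
//   %v = sparse_tensor.values %t : tensor<?xf64, #SparseVector> to memref<?xf64>
// is rewritten into a library call along the lines of
//   %v = call @sparseValuesF64(%t) : (!llvm.ptr<i8>) -> memref<?xf64>
// (an illustrative sketch only; the exact IR depends on the element type and
// the chosen overhead bit widths, and #SparseVector stands for any sparse
// tensor encoding).
//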
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

/// Actions performed by the runtime support library when materializing
/// new sparse tensor storage. Keep these values consistent with the sparse
/// runtime support library.
enum Action : uint32_t {
  kEmpty = 0,
  kFromFile = 1,
  kFromCOO = 2,
  kEmptyCOO = 3,
  kToCOO = 4
};
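
// Note that the conversion rules below combine these actions: for instance,
// the convert rule first materializes an intermediate coordinate scheme
// (kEmptyCOO or kToCOO) and then constructs final storage from it (kFromCOO).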

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

/// Returns internal type encoding for primary storage. Keep these
/// values consistent with the sparse runtime support library.
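/// Returns zero for unsupported element types (callers are expected to
/// assert against this, as done in newParams() below).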
static unsigned getPrimaryTypeEncoding(Type tp) {
  if (tp.isF64())
    return 1;
  if (tp.isF32())
    return 2;
  if (tp.isInteger(64))
    return 3;
  if (tp.isInteger(32))
    return 4;
  if (tp.isInteger(16))
    return 5;
  if (tp.isInteger(8))
    return 6;
  return 0;
}

/// Returns internal type encoding for overhead storage. Keep these
/// values consistent with the sparse runtime support library.
static unsigned getOverheadTypeEncoding(unsigned width) {
  switch (width) {
  default:
    return 1; // 64-bit
  case 32:
    return 2;
  case 16:
    return 3;
  case 8:
    return 4;
  }
}

/// Returns internal dimension level type encoding. Keep these
/// values consistent with the sparse runtime support library.
static unsigned
getDimLevelTypeEncoding(SparseTensorEncodingAttr::DimLevelType dlt) {
  switch (dlt) {
  case SparseTensorEncodingAttr::DimLevelType::Dense:
    return 0;
  case SparseTensorEncodingAttr::DimLevelType::Compressed:
    return 1;
  case SparseTensorEncodingAttr::DimLevelType::Singleton:
    return 2;
  }
  llvm_unreachable("Unknown SparseTensorEncodingAttr::DimLevelType");
}

/// Generates a constant zero of the given type.
inline static Value constantZero(ConversionPatternRewriter &rewriter,
                                 Location loc, Type t) {
  return rewriter.create<arith::ConstantOp>(loc, t, rewriter.getZeroAttr(t));
}

/// Generates a constant of `index` type.
inline static Value constantIndex(ConversionPatternRewriter &rewriter,
                                  Location loc, int64_t i) {
  return rewriter.create<arith::ConstantIndexOp>(loc, i);
}

/// Generates a constant of `i64` type.
inline static Value constantI64(ConversionPatternRewriter &rewriter,
                                Location loc, int64_t i) {
  return rewriter.create<arith::ConstantIntOp>(loc, i, 64);
}

/// Generates a constant of `i32` type.
inline static Value constantI32(ConversionPatternRewriter &rewriter,
                                Location loc, int32_t i) {
  return rewriter.create<arith::ConstantIntOp>(loc, i, 32);
}

/// Generates a constant of `i8` type.
inline static Value constantI8(ConversionPatternRewriter &rewriter,
                               Location loc, int8_t i) {
  return rewriter.create<arith::ConstantIntOp>(loc, i, 8);
}

/// Returns a function reference (first hit also inserts into module). Sets
/// the "llvm.emit_c_interface" attribute on the function declaration when
/// requested, so that LLVM lowering generates a wrapper function that takes
/// care of ABI complications with passing in and returning MemRefs to C
/// functions.
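/// (Concretely, this attribute makes the LLVM lowering emit an additional
/// `_mlir_ciface_<name>` wrapper through which MemRef arguments are passed
/// and returned indirectly.)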
static FlatSymbolRefAttr getFunc(Operation *op, StringRef name,
                                 TypeRange resultType, ValueRange operands,
                                 bool emitCInterface = false) {
  MLIRContext *context = op->getContext();
  auto module = op->getParentOfType<ModuleOp>();
  auto result = SymbolRefAttr::get(context, name);
  auto func = module.lookupSymbol<FuncOp>(result.getAttr());
  if (!func) {
    OpBuilder moduleBuilder(module.getBodyRegion());
    func = moduleBuilder.create<FuncOp>(
        op->getLoc(), name,
        FunctionType::get(context, operands.getTypes(), resultType));
    func.setPrivate();
    if (emitCInterface)
      func->setAttr("llvm.emit_c_interface", UnitAttr::get(context));
  }
  return result;
}

/// Generates a call that obtains the size of the given dimension.
static Value genDimSizeCall(ConversionPatternRewriter &rewriter, Operation *op,
                            SparseTensorEncodingAttr &enc, Value src,
                            int64_t idx) {
  // Permute the index according to an optional dimension ordering.
  if (AffineMap p = enc.getDimOrdering())
    idx = p.getPermutedPosition(idx);
  // Generate the call.
  Location loc = op->getLoc();
  StringRef name = "sparseDimSize";
  SmallVector<Value, 2> params;
  params.push_back(src);
  params.push_back(constantIndex(rewriter, loc, idx));
  Type iTp = rewriter.getIndexType();
  auto fn = getFunc(op, name, iTp, params);
  return rewriter.create<CallOp>(loc, iTp, fn, params).getResult(0);
}

/// Generates a call into the "swiss army knife" method of the sparse runtime
/// support library for materializing sparse tensors into the computation.
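/// The returned opaque pointer (!llvm.ptr<i8>) is what these rules use as
/// the converted form of annotated sparse tensor values.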
static Value genNewCall(ConversionPatternRewriter &rewriter, Operation *op,
                        ArrayRef<Value> params) {
  Location loc = op->getLoc();
  StringRef name = "newSparseTensor";
  Type pTp = LLVM::LLVMPointerType::get(rewriter.getI8Type());
  auto fn = getFunc(op, name, pTp, params, /*emitCInterface=*/true);
  auto call = rewriter.create<CallOp>(loc, pTp, fn, params);
  return call.getResult(0);
}

/// Populates given sizes array from type.
static void sizesFromType(ConversionPatternRewriter &rewriter,
                          SmallVector<Value, 4> &sizes, Location loc,
                          ShapedType stp) {
  auto shape = stp.getShape();
  for (unsigned i = 0, rank = stp.getRank(); i < rank; i++) {
    uint64_t s = shape[i] == ShapedType::kDynamicSize ? 0 : shape[i];
    sizes.push_back(constantIndex(rewriter, loc, s));
  }
}

/// Populates given sizes array from source.
static void sizesFromSrc(ConversionPatternRewriter &rewriter,
                         SmallVector<Value, 4> &sizes, Location loc,
                         Value src) {
  ShapedType stp = src.getType().cast<ShapedType>();
  for (unsigned i = 0, rank = stp.getRank(); i < rank; i++)
    sizes.push_back(linalg::createOrFoldDimOp(rewriter, loc, src, i));
}

/// Populates the given sizes array from the type (for static sizes) and
/// from the source, which has already been converted into an opaque pointer
/// (for dynamic sizes).
static void sizesFromPtr(ConversionPatternRewriter &rewriter,
                         SmallVector<Value, 4> &sizes, Operation *op,
                         SparseTensorEncodingAttr &enc, ShapedType stp,
                         Value src) {
  auto shape = stp.getShape();
  for (unsigned i = 0, rank = stp.getRank(); i < rank; i++)
    if (shape[i] == ShapedType::kDynamicSize)
      sizes.push_back(genDimSizeCall(rewriter, op, enc, src, i));
    else
      sizes.push_back(constantIndex(rewriter, op->getLoc(), shape[i]));
}

/// Generates a temporary buffer of the given size and type.
static Value genAlloca(ConversionPatternRewriter &rewriter, Location loc,
                       unsigned sz, Type tp) {
  auto memTp = MemRefType::get({ShapedType::kDynamicSize}, tp);
  Value a = constantIndex(rewriter, loc, sz);
  return rewriter.create<memref::AllocaOp>(loc, memTp, ValueRange{a});
}

/// Generates a temporary buffer of the given type and given contents.
static Value genBuffer(ConversionPatternRewriter &rewriter, Location loc,
                       ArrayRef<Value> values) {
  unsigned sz = values.size();
  assert(sz >= 1);
  Value buffer = genAlloca(rewriter, loc, sz, values[0].getType());
  for (unsigned i = 0; i < sz; i++) {
    Value idx = constantIndex(rewriter, loc, i);
    rewriter.create<memref::StoreOp>(loc, values[i], buffer, idx);
  }
  return buffer;
}

/// Populates parameters required to call the "swiss army knife" method of the
/// sparse runtime support library for materializing sparse tensors into the
/// computation.
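/// The resulting layout, which conversion rules below rely on when patching
/// entries in place, is: params[0] per-dimension annotations, params[1]
/// dimension sizes, params[2] dimension ordering, params[3..5] the pointer,
/// index, and primary type encodings, params[6] the action, and params[7]
/// the opaque pointer argument.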
static void newParams(ConversionPatternRewriter &rewriter,
                      SmallVector<Value, 8> &params, Operation *op,
                      SparseTensorEncodingAttr &enc, uint32_t action,
                      ValueRange szs, Value ptr = Value()) {
  Location loc = op->getLoc();
  ArrayRef<SparseTensorEncodingAttr::DimLevelType> dlt = enc.getDimLevelType();
  unsigned sz = dlt.size();
  // Sparsity annotations.
  SmallVector<Value, 4> attrs;
  for (unsigned i = 0; i < sz; i++)
    attrs.push_back(constantI8(rewriter, loc, getDimLevelTypeEncoding(dlt[i])));
  params.push_back(genBuffer(rewriter, loc, attrs));
  // Dimension sizes array of the enveloping tensor. Useful for either
  // verification of external data, or for construction of internal data.
  // The index type is cast to i64 for API consistency.
  Type iTp = rewriter.getI64Type();
  SmallVector<Value, 4> sizes;
  for (Value s : szs)
    sizes.push_back(rewriter.create<arith::IndexCastOp>(loc, s, iTp));
  params.push_back(genBuffer(rewriter, loc, sizes));
  // Dimension order permutation array. This is the "identity" permutation by
  // default, or otherwise the "reverse" permutation of a given ordering, so
  // that indices can be mapped quickly to the right position.
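  // For example, under the ordering (d0, d1, d2) -> (d2, d0, d1), this yields
  // rev = [1, 2, 0], i.e., rev[d] is the storage position of dimension d.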
  SmallVector<Value, 4> rev(sz);
  if (AffineMap p = enc.getDimOrdering()) {
    for (unsigned i = 0; i < sz; i++)
      rev[p.getDimPosition(i)] = constantI64(rewriter, loc, i);
  } else {
    for (unsigned i = 0; i < sz; i++)
      rev[i] = constantI64(rewriter, loc, i);
  }
  params.push_back(genBuffer(rewriter, loc, rev));
  // Secondary and primary types encoding.
  ShapedType resType = op->getResult(0).getType().cast<ShapedType>();
  unsigned secPtr = getOverheadTypeEncoding(enc.getPointerBitWidth());
  unsigned secInd = getOverheadTypeEncoding(enc.getIndexBitWidth());
  unsigned primary = getPrimaryTypeEncoding(resType.getElementType());
  assert(primary);
  params.push_back(constantI64(rewriter, loc, secPtr));
  params.push_back(constantI64(rewriter, loc, secInd));
  params.push_back(constantI64(rewriter, loc, primary));
  // User action and pointer.
  Type pTp = LLVM::LLVMPointerType::get(rewriter.getI8Type());
  if (!ptr)
    ptr = rewriter.create<LLVM::NullOp>(loc, pTp);
  params.push_back(constantI32(rewriter, loc, action));
  params.push_back(ptr);
}

/// Generates the comparison `v != 0` where `v` is of numeric type `t`.
/// For floating-point types, we use the "unordered" comparator UNE (i.e.,
/// the comparison also returns true if `v` is NaN), so that NaN values are
/// treated as nonzero and retained.
static Value genIsNonzero(ConversionPatternRewriter &rewriter, Location loc,
                          Value v) {
  Type t = v.getType();
  Value zero = constantZero(rewriter, loc, t);
  if (t.isa<FloatType>())
    return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE, v,
                                          zero);
  if (t.isIntOrIndex())
    return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, v,
                                          zero);
  llvm_unreachable("Unknown element type");
}

/// Generates code to read the value from tensor[ivs] and, if it is nonzero,
/// to store the indices ivs to the memory in ind. The insertion point after
/// this routine is inside the if-then branch, behind the assignment to ind,
/// so that the addEltX call generated afterwards lands inside the if-then
/// branch as well. The generated code looks like the following:
///    if (tensor[ivs] != 0) {
///      ind = ivs
static Value genIndexAndValueForDense(ConversionPatternRewriter &rewriter,
                                      Location loc, Value tensor, Value ind,
                                      ValueRange ivs) {
  Value val = rewriter.create<tensor::ExtractOp>(loc, tensor, ivs);
  Value cond = genIsNonzero(rewriter, loc, val);
  scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ false);
  rewriter.setInsertionPointToStart(&ifOp.thenRegion().front());
  unsigned i = 0;
  for (auto iv : ivs) {
    Value idx = constantIndex(rewriter, loc, i++);
    rewriter.create<memref::StoreOp>(loc, iv, ind, idx);
  }
  return val;
}

/// Generates a call that adds one element to a coordinate scheme.
/// In particular, this generates code like the following:
///   val = a[i1,..,ik];
///   if val != 0
///     t->add(val, [i1,..,ik], [p1,..,pk]);
static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
                          Type eltType, Value ptr, Value val, Value ind,
                          Value perm) {
  Location loc = op->getLoc();
  StringRef name;
  if (eltType.isF64())
    name = "addEltF64";
  else if (eltType.isF32())
    name = "addEltF32";
  else if (eltType.isInteger(64))
    name = "addEltI64";
  else if (eltType.isInteger(32))
    name = "addEltI32";
  else if (eltType.isInteger(16))
    name = "addEltI16";
  else if (eltType.isInteger(8))
    name = "addEltI8";
  else
    llvm_unreachable("Unknown element type");
  SmallVector<Value, 8> params;
  params.push_back(ptr);
  params.push_back(val);
  params.push_back(ind);
  params.push_back(perm);
  Type pTp = LLVM::LLVMPointerType::get(rewriter.getI8Type());
  auto fn = getFunc(op, name, pTp, params, /*emitCInterface=*/true);
  rewriter.create<CallOp>(loc, pTp, fn, params);
}

/// If the tensor is a sparse constant, generates and returns the pair of
/// constants for its indices and its values.
static Optional<std::pair<Value, Value>>
genSplitSparseConstant(ConversionPatternRewriter &rewriter, Location loc,
                       Value tensor) {
  if (auto constOp = tensor.getDefiningOp<arith::ConstantOp>()) {
    if (auto attr = constOp.value().dyn_cast<SparseElementsAttr>()) {
      DenseElementsAttr indicesAttr = attr.getIndices();
      Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr);
      DenseElementsAttr valuesAttr = attr.getValues();
      Value values = rewriter.create<arith::ConstantOp>(loc, valuesAttr);
      return std::make_pair(indices, values);
    }
  }
  return {};
}

/// Generates code to copy the index at indices[ivs] to ind, and returns
/// the value at values[ivs].
static Value genIndexAndValueForSparse(ConversionPatternRewriter &rewriter,
                                       Location loc, Value indices,
                                       Value values, Value ind, ValueRange ivs,
                                       unsigned rank) {
  for (unsigned i = 0; i < rank; i++) {
    Value idx = constantIndex(rewriter, loc, i);
    Value val = rewriter.create<tensor::ExtractOp>(loc, indices,
                                                   ValueRange{ivs[0], idx});
    val =
        rewriter.create<arith::IndexCastOp>(loc, val, rewriter.getIndexType());
    rewriter.create<memref::StoreOp>(loc, val, ind, idx);
  }
  return rewriter.create<tensor::ExtractOp>(loc, values, ivs[0]);
}

//===----------------------------------------------------------------------===//
// Conversion rules.
//===----------------------------------------------------------------------===//

/// Sparse conversion rule for returns.
class SparseReturnConverter : public OpConversionPattern<ReturnOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<ReturnOp>(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for dimension accesses.
class SparseTensorToDimSizeConverter
    : public OpConversionPattern<tensor::DimOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only rewrite annotated DimOp with constant index.
    auto enc = getSparseTensorEncoding(op.source().getType());
    if (!enc)
      return failure();
    Optional<int64_t> index = op.getConstantIndex();
    if (!index.hasValue())
      return failure();
    // Generate the call.
    Value src = adaptor.getOperands()[0];
    int64_t idx = index.getValue();
    rewriter.replaceOp(op, genDimSizeCall(rewriter, op, enc, src, idx));
    return success();
  }
};

/// Sparse conversion rule for the new operator.
class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NewOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    auto enc = getSparseTensorEncoding(resType);
    if (!enc)
      return failure();
    // Generate the call to construct tensor from ptr. The sizes are
    // inferred from the result type of the new operator.
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 8> params;
    sizesFromType(rewriter, sizes, op.getLoc(), resType.cast<ShapedType>());
    Value ptr = adaptor.getOperands()[0];
    newParams(rewriter, params, op, enc, kFromFile, sizes, ptr);
    rewriter.replaceOp(op, genNewCall(rewriter, op, params));
    return success();
  }
};

/// Sparse conversion rule for the init operator.
class SparseTensorInitConverter : public OpConversionPattern<InitOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(InitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    auto enc = getSparseTensorEncoding(resType);
    if (!enc)
      return failure();
    // Generate the call to construct empty tensor. The sizes are
    // explicitly defined by the arguments to the init operator.
    SmallVector<Value, 8> params;
    newParams(rewriter, params, op, enc, kEmpty, adaptor.getOperands());
    rewriter.replaceOp(op, genNewCall(rewriter, op, params));
    return success();
  }
};

/// Sparse conversion rule for the convert operator.
class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ConvertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    Type resType = op.getType();
    Type srcType = op.source().getType();
    auto encDst = getSparseTensorEncoding(resType);
    auto encSrc = getSparseTensorEncoding(srcType);
    Value src = adaptor.getOperands()[0];
    if (encDst && encSrc) {
      // This is a sparse => sparse conversion, which is handled as follows:
      //   t = src->toCOO();         ; src to COO in dst order
      //   dst = newSparseTensor(t)
      // Using the coordinate scheme as an intermediate does not always
      // yield the fastest conversion but avoids the need for a full
      // O(N^2) conversion matrix.
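      // (For example, a CSR => CSC conversion first extracts a COO copy in
      // the column-major destination order, then rebuilds compressed storage
      // from that copy.)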
      SmallVector<Value, 4> sizes;
      SmallVector<Value, 8> params;
      sizesFromPtr(rewriter, sizes, op, encSrc, srcType.cast<ShapedType>(),
                   src);
      newParams(rewriter, params, op, encDst, kToCOO, sizes, src);
      Value coo = genNewCall(rewriter, op, params);
      params[6] = constantI32(rewriter, loc, kFromCOO);
      params[7] = coo;
      rewriter.replaceOp(op, genNewCall(rewriter, op, params));
      return success();
    }
    if (!encDst || encSrc) {
      // TODO: sparse => dense (this branch also rejects the unannotated
      // dense => dense case).
      return failure();
    }
    // This is a dense => sparse conversion or a sparse constant in COO =>
    // sparse conversion, which is handled as follows:
    //   t = newSparseCOO()
    //   ...code to fill the COO tensor t...
    //   s = newSparseTensor(t)
    //
    // To fill the COO tensor from a dense tensor:
    //   for i1 in dim1
    //    ..
    //     for ik in dimk
    //       val = a[i1,..,ik]
    //       if val != 0
    //         t->add(val, [i1,..,ik], [p1,..,pk])
    //
    // To fill the COO tensor from a sparse constant in COO format:
    //   for i in range(NNZ)
    //     val = values[i]
    //     [i1,..,ik] = indices[i]
    //     t->add(val, [i1,..,ik], [p1,..,pk])
    //
    // Note that the dense tensor traversal code is actually implemented
    // using MLIR IR to avoid having to expose too much low-level
    // memref traversal detail to the runtime support library.
    // Also note that the code below only generates the "new" ops and
    // the loop-nest per se; the body of the innermost loop is generated
    // by the genIndexAndValueFor* and genAddEltCall() helpers above.
    ShapedType stp = resType.cast<ShapedType>();
    unsigned rank = stp.getRank();
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 8> params;
    sizesFromSrc(rewriter, sizes, loc, src);
    newParams(rewriter, params, op, encDst, kEmptyCOO, sizes);
    Value ptr = genNewCall(rewriter, op, params);
    Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
    Value perm = params[2];
    SmallVector<Value> lo;
    SmallVector<Value> hi;
    SmallVector<Value> st;
    Value zero = constantIndex(rewriter, loc, 0);
    Value one = constantIndex(rewriter, loc, 1);
    auto indicesValues = genSplitSparseConstant(rewriter, loc, src);
    bool isCOOConstant = indicesValues.hasValue();
    Value indices;
    Value values;
    if (isCOOConstant) {
      indices = indicesValues->first;
      values = indicesValues->second;
      lo.push_back(zero);
      hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, values, 0));
      st.push_back(one);
    } else {
      for (unsigned i = 0; i < rank; i++) {
        lo.push_back(zero);
        hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, src, i));
        st.push_back(one);
      }
    }
    Type eltType = stp.getElementType();
    scf::buildLoopNest(
        rewriter, op.getLoc(), lo, hi, st, {},
        [&](OpBuilder &builder, Location loc, ValueRange ivs,
            ValueRange args) -> scf::ValueVector {
          Value val;
          if (isCOOConstant)
            val = genIndexAndValueForSparse(rewriter, loc, indices, values, ind,
                                            ivs, rank);
          else
            val = genIndexAndValueForDense(rewriter, loc, src, ind, ivs);
          genAddEltCall(rewriter, op, eltType, ptr, val, ind, perm);
          return {};
        });
    // Final call to construct sparse tensor storage.
    params[6] = constantI32(rewriter, loc, kFromCOO);
    params[7] = ptr;
    rewriter.replaceOp(op, genNewCall(rewriter, op, params));
    return success();
  }
};

/// Sparse conversion rule for the release operator.
class SparseTensorReleaseConverter : public OpConversionPattern<ReleaseOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReleaseOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    StringRef name = "delSparseTensor";
    TypeRange none;
    auto fn = getFunc(op, name, none, adaptor.getOperands());
    rewriter.create<CallOp>(op.getLoc(), none, fn, adaptor.getOperands());
    rewriter.eraseOp(op);
    return success();
  }
};

/// Sparse conversion rule for pointer accesses.
class SparseTensorToPointersConverter
    : public OpConversionPattern<ToPointersOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToPointersOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type eltType = resType.cast<ShapedType>().getElementType();
    StringRef name;
    if (eltType.isIndex())
      name = "sparsePointers"; // 64-bit, but its own name for unique signature
    else if (eltType.isInteger(64))
      name = "sparsePointers64";
    else if (eltType.isInteger(32))
      name = "sparsePointers32";
    else if (eltType.isInteger(16))
      name = "sparsePointers16";
    else if (eltType.isInteger(8))
      name = "sparsePointers8";
    else
      return failure();
    auto fn = getFunc(op, name, resType, adaptor.getOperands(),
                      /*emitCInterface=*/true);
    rewriter.replaceOpWithNewOp<CallOp>(op, resType, fn, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for index accesses.
class SparseTensorToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToIndicesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type eltType = resType.cast<ShapedType>().getElementType();
    StringRef name;
    if (eltType.isIndex())
      name = "sparseIndices"; // 64-bit, but its own name for unique signature
    else if (eltType.isInteger(64))
      name = "sparseIndices64";
    else if (eltType.isInteger(32))
      name = "sparseIndices32";
    else if (eltType.isInteger(16))
      name = "sparseIndices16";
    else if (eltType.isInteger(8))
      name = "sparseIndices8";
    else
      return failure();
    auto fn = getFunc(op, name, resType, adaptor.getOperands(),
                      /*emitCInterface=*/true);
    rewriter.replaceOpWithNewOp<CallOp>(op, resType, fn, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for value accesses.
class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type eltType = resType.cast<ShapedType>().getElementType();
    StringRef name;
    if (eltType.isF64())
      name = "sparseValuesF64";
    else if (eltType.isF32())
      name = "sparseValuesF32";
    else if (eltType.isInteger(64))
      name = "sparseValuesI64";
    else if (eltType.isInteger(32))
      name = "sparseValuesI32";
    else if (eltType.isInteger(16))
      name = "sparseValuesI16";
    else if (eltType.isInteger(8))
      name = "sparseValuesI8";
    else
      return failure();
    auto fn = getFunc(op, name, resType, adaptor.getOperands(),
                      /*emitCInterface=*/true);
    rewriter.replaceOpWithNewOp<CallOp>(op, resType, fn, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for tensor reconstruction.
class SparseTensorToTensorConverter : public OpConversionPattern<ToTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  // Simply fold the operator into the pointer to the sparse storage scheme.
  LogicalResult
  matchAndRewrite(ToTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Check that all arguments of the tensor reconstruction operator are
    // calls into the support library that query exactly the same opaque
    // pointer.
    Value ptr;
    for (Value operand : adaptor.getOperands()) {
      if (auto call = operand.getDefiningOp<CallOp>()) {
        Value arg = call.getOperand(0);
        if (!arg.getType().isa<LLVM::LLVMPointerType>())
          return failure();
        if (!ptr)
          ptr = arg;
        else if (arg != ptr)
          return failure();
      }
    }
    // If a single opaque pointer is found, perform the folding.
    if (!ptr)
      return failure();
    rewriter.replaceOp(op, ptr);
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with conversion rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
                                                  RewritePatternSet &patterns) {
  patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter,
               SparseTensorNewConverter, SparseTensorInitConverter,
               SparseTensorConvertConverter, SparseTensorReleaseConverter,
               SparseTensorToPointersConverter, SparseTensorToIndicesConverter,
               SparseTensorToValuesConverter, SparseTensorToTensorConverter>(
      typeConverter, patterns.getContext());
}
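
// A minimal usage sketch (assuming the enclosing conversion pass sets up a
// suitable type converter and conversion target, as done elsewhere in the
// sparse tensor pass infrastructure):
//
//   RewritePatternSet patterns(ctx);
//   populateSparseTensorConversionPatterns(typeConverter, patterns);
//   if (failed(applyPartialConversion(module, target, std::move(patterns))))
//     signalPassFailure();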