//===- SparseTensorConversion.cpp - Sparse tensor primitives conversion ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Convert sparse tensor primitives to calls into a runtime support library.
// Note that this is a current implementation choice to keep the conversion
// simple. In principle, these primitives could also be converted to actual
// elaborate IR code that implements the primitives on the selected sparse
// tensor storage schemes.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

/// Returns internal type encoding for primary storage. Keep these
/// values consistent with the sparse runtime support library.
static unsigned getPrimaryTypeEncoding(Type tp) {
  if (tp.isF64())
    return 1;
  if (tp.isF32())
    return 2;
  if (tp.isInteger(64))
    return 3;
  if (tp.isInteger(32))
    return 4;
  if (tp.isInteger(16))
    return 5;
  if (tp.isInteger(8))
    return 6;
  return 0;
}

/// Returns internal type encoding for overhead storage. Keep these
/// values consistent with the sparse runtime support library.
static unsigned getOverheadTypeEncoding(unsigned width) {
  switch (width) {
  default:
    return 1;
  case 32:
    return 2;
  case 16:
    return 3;
  case 8:
    return 4;
  }
}

/// Returns internal dimension level type encoding. Keep these
/// values consistent with the sparse runtime support library.
static unsigned
getDimLevelTypeEncoding(SparseTensorEncodingAttr::DimLevelType dlt) {
  switch (dlt) {
  case SparseTensorEncodingAttr::DimLevelType::Dense:
    return 0;
  case SparseTensorEncodingAttr::DimLevelType::Compressed:
    return 1;
  case SparseTensorEncodingAttr::DimLevelType::Singleton:
    return 2;
  }
  llvm_unreachable("Unknown SparseTensorEncodingAttr::DimLevelType");
}
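// A minimal illustration of the encodings above (the values simply mirror
// the helpers, which in turn must match the runtime library): a CSR-style
// matrix such as
//   tensor<?x?xf64,
//          #sparse_tensor.encoding<{dimLevelType = ["dense", "compressed"]}>>
// uses primary type encoding 1 (f64) and dimension level type encodings
// [0, 1] (dense, compressed).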
/// Generates a constant zero of the given type.
inline static Value constantZero(ConversionPatternRewriter &rewriter,
                                 Location loc, Type t) {
  return rewriter.create<arith::ConstantOp>(loc, t, rewriter.getZeroAttr(t));
}

/// Generates a constant of `index` type.
inline static Value constantIndex(ConversionPatternRewriter &rewriter,
                                  Location loc, unsigned i) {
  return rewriter.create<arith::ConstantIndexOp>(loc, i);
}

/// Generates a constant of `i64` type.
inline static Value constantI64(ConversionPatternRewriter &rewriter,
                                Location loc, int64_t i) {
  return rewriter.create<arith::ConstantIntOp>(loc, i, 64);
}

/// Generates a constant of `i32` type.
inline static Value constantI32(ConversionPatternRewriter &rewriter,
                                Location loc, int32_t i) {
  return rewriter.create<arith::ConstantIntOp>(loc, i, 32);
}

/// Returns a constant tensor of integers with the given width and values.
/// We cast the static shape into a dynamic shape to ensure that the
/// method signature remains uniform across different tensor dimensions.
static Value getTensor(ConversionPatternRewriter &rewriter, unsigned width,
                       Location loc, ArrayRef<APInt> values) {
  Type etp = rewriter.getIntegerType(width);
  unsigned sz = values.size();
  RankedTensorType tt1 = RankedTensorType::get({sz}, etp);
  RankedTensorType tt2 = RankedTensorType::get({ShapedType::kDynamicSize}, etp);
  auto elts = rewriter.create<arith::ConstantOp>(
      loc, DenseElementsAttr::get(tt1, values));
  return rewriter.create<tensor::CastOp>(loc, tt2, elts);
}

/// Returns a function reference (first hit also inserts into module). Sets
/// the "llvm.emit_c_interface" attribute on the function declaration when
/// requested, so that LLVM lowering generates a wrapper function that takes
/// care of ABI complications with passing in and returning MemRefs to C
/// functions.
static FlatSymbolRefAttr getFunc(Operation *op, StringRef name,
                                 TypeRange resultType, ValueRange operands,
                                 bool emitCInterface = false) {
  MLIRContext *context = op->getContext();
  auto module = op->getParentOfType<ModuleOp>();
  auto result = SymbolRefAttr::get(context, name);
  auto func = module.lookupSymbol<FuncOp>(result.getAttr());
  if (!func) {
    OpBuilder moduleBuilder(module.getBodyRegion());
    func = moduleBuilder.create<FuncOp>(
        op->getLoc(), name,
        FunctionType::get(context, operands.getTypes(), resultType));
    func.setPrivate();
    if (emitCInterface)
      func->setAttr("llvm.emit_c_interface", UnitAttr::get(context));
  }
  return result;
}
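// For example (illustrative only), the first request for "newSparseTensor"
// with emitCInterface=true inserts a private declaration along the lines of
//   func private @newSparseTensor(...) -> !llvm.ptr<i8>
//       attributes {llvm.emit_c_interface}
// into the module; later requests find the existing declaration and simply
// return a reference to its symbol.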
/// Generates a call into the "swiss army knife" method of the sparse runtime
/// support library for materializing sparse tensors into the computation. The
/// method returns the call value and assigns the permutation to 'perm'.
static Value genNewCall(ConversionPatternRewriter &rewriter, Operation *op,
                        SparseTensorEncodingAttr &enc, uint32_t action,
                        Value &perm, Value ptr = Value()) {
  Location loc = op->getLoc();
  ShapedType resType = op->getResult(0).getType().cast<ShapedType>();
  SmallVector<Value, 8> params;
  // Sparsity annotations in tensor constant form.
  SmallVector<APInt, 4> attrs;
  unsigned sz = enc.getDimLevelType().size();
  for (unsigned i = 0; i < sz; i++)
    attrs.push_back(
        APInt(8, getDimLevelTypeEncoding(enc.getDimLevelType()[i])));
  params.push_back(getTensor(rewriter, 8, loc, attrs));
  // Dimension sizes array of the enveloping *dense* tensor. Useful for either
  // verification of external data, or for construction of internal data.
  auto shape = resType.getShape();
  SmallVector<APInt, 4> sizes;
  for (unsigned i = 0; i < sz; i++) {
    uint64_t s = shape[i] == ShapedType::kDynamicSize ? 0 : shape[i];
    sizes.push_back(APInt(64, s));
  }
  params.push_back(getTensor(rewriter, 64, loc, sizes));
  // Dimension order permutation array. This is the "identity" permutation by
  // default, or otherwise the "reverse" permutation of a given ordering, so
  // that indices can be mapped quickly to the right position.
  SmallVector<APInt, 4> rev(sz);
  if (AffineMap p = enc.getDimOrdering()) {
    for (unsigned i = 0; i < sz; i++)
      rev[p.getDimPosition(i)] = APInt(64, i);
  } else {
    for (unsigned i = 0; i < sz; i++)
      rev[i] = APInt(64, i);
  }
  perm = getTensor(rewriter, 64, loc, rev);
  params.push_back(perm);
  // Secondary and primary types encoding.
  unsigned secPtr = getOverheadTypeEncoding(enc.getPointerBitWidth());
  unsigned secInd = getOverheadTypeEncoding(enc.getIndexBitWidth());
  unsigned primary = getPrimaryTypeEncoding(resType.getElementType());
  assert(primary);
  params.push_back(constantI64(rewriter, loc, secPtr));
  params.push_back(constantI64(rewriter, loc, secInd));
  params.push_back(constantI64(rewriter, loc, primary));
  // User action and pointer.
  Type pTp = LLVM::LLVMPointerType::get(rewriter.getI8Type());
  if (!ptr)
    ptr = rewriter.create<LLVM::NullOp>(loc, pTp);
  params.push_back(constantI32(rewriter, loc, action));
  params.push_back(ptr);
  // Generate the call to create the new tensor.
  StringRef name = "newSparseTensor";
  auto call = rewriter.create<CallOp>(
      loc, pTp, getFunc(op, name, pTp, params, /*emitCInterface=*/true),
      params);
  return call.getResult(0);
}
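// As used below, the `action` values follow the runtime convention (inferred
// from the call sites in this file): 0 materializes a tensor from an opaque
// pointer, 1 builds a tensor from a COO scheme, 2 creates an empty COO
// scheme, and 3 converts an existing tensor to COO. An illustrative call for
// action 0 would look like:
//   %t = call @newSparseTensor(%annotations, %sizes, %perm,
//                              %secPtr, %secInd, %primary, %c0_i32, %ptr)
//     : (tensor<?xi8>, tensor<?xi64>, tensor<?xi64>,
//        i64, i64, i64, i32, !llvm.ptr<i8>) -> !llvm.ptr<i8>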
/// Generates the comparison `v != 0` where `v` is of numeric type `t`.
/// For floating types, we use the "unordered" comparator (i.e., returns
/// true if `v` is NaN).
static Value genIsNonzero(ConversionPatternRewriter &rewriter, Location loc,
                          Value v) {
  Type t = v.getType();
  Value zero = constantZero(rewriter, loc, t);
  if (t.isa<FloatType>())
    return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE, v,
                                          zero);
  if (t.isIntOrIndex())
    return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, v,
                                          zero);
  llvm_unreachable("Unknown element type");
}

/// Generates code to read the value from tensor[ivs] and, if it is nonzero,
/// to store the indices ivs to the memory in ind. The generated code looks
/// like the following, and the insertion point after this routine is inside
/// the if-then branch, right after the assignment to ind. This ensures that
/// the addEltX call generated afterwards is also inside the if-then branch.
///    if (tensor[ivs] != 0)
///      ind = ivs
static Value genIndexAndValueForDense(ConversionPatternRewriter &rewriter,
                                      Location loc, Value tensor, Value ind,
                                      ValueRange ivs) {
  Value val = rewriter.create<tensor::ExtractOp>(loc, tensor, ivs);
  Value cond = genIsNonzero(rewriter, loc, val);
  scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ false);
  rewriter.setInsertionPointToStart(&ifOp.thenRegion().front());
  unsigned i = 0;
  for (auto iv : ivs) {
    Value idx = constantIndex(rewriter, loc, i++);
    rewriter.create<memref::StoreOp>(loc, iv, ind, idx);
  }
  return val;
}

/// Generates a call that adds one element to a coordinate scheme.
/// In particular, this generates code like the following:
///   val = a[i1,..,ik];
///   if val != 0
///     t->add(val, [i1,..,ik], [p1,..,pk]);
static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
                          Type eltType, Value ptr, Value val, Value ind,
                          Value perm) {
  Location loc = op->getLoc();
  StringRef name;
  if (eltType.isF64())
    name = "addEltF64";
  else if (eltType.isF32())
    name = "addEltF32";
  else if (eltType.isInteger(64))
    name = "addEltI64";
  else if (eltType.isInteger(32))
    name = "addEltI32";
  else if (eltType.isInteger(16))
    name = "addEltI16";
  else if (eltType.isInteger(8))
    name = "addEltI8";
  else
    llvm_unreachable("Unknown element type");
  SmallVector<Value, 8> params;
  params.push_back(ptr);
  params.push_back(val);
  params.push_back(ind);
  params.push_back(perm);
  Type pTp = LLVM::LLVMPointerType::get(rewriter.getI8Type());
  rewriter.create<CallOp>(
      loc, pTp, getFunc(op, name, pTp, params, /*emitCInterface=*/true),
      params);
}

/// If the tensor is a sparse constant, generates and returns the pair of
/// constants for the indices and the values.
static Optional<std::pair<Value, Value>>
genSplitSparseConstant(ConversionPatternRewriter &rewriter, Location loc,
                       Value tensor) {
  if (auto constOp = tensor.getDefiningOp<arith::ConstantOp>()) {
    if (auto attr = constOp.value().dyn_cast<SparseElementsAttr>()) {
      DenseElementsAttr indicesAttr = attr.getIndices();
      Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr);
      DenseElementsAttr valuesAttr = attr.getValues();
      Value values = rewriter.create<arith::ConstantOp>(loc, valuesAttr);
      return std::make_pair(indices, values);
    }
  }
  return {};
}
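// For example (illustrative), splitting the sparse constant
//   %t = arith.constant sparse<[[0, 1], [2, 3]], [1.0, 5.0]> : tensor<4x4xf64>
// yields an index constant of type tensor<2x2xi64> holding [[0, 1], [2, 3]]
// and a value constant of type tensor<2xf64> holding [1.0, 5.0].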
/// Generates code to copy the index at indices[ivs] to ind, and returns
/// the value at values[ivs].
static Value genIndexAndValueForSparse(ConversionPatternRewriter &rewriter,
                                       Location loc, Value indices,
                                       Value values, Value ind, ValueRange ivs,
                                       unsigned rank) {
  for (unsigned i = 0; i < rank; i++) {
    Value idx = constantIndex(rewriter, loc, i);
    Value val = rewriter.create<tensor::ExtractOp>(loc, indices,
                                                   ValueRange{ivs[0], idx});
    val =
        rewriter.create<arith::IndexCastOp>(loc, val, rewriter.getIndexType());
    rewriter.create<memref::StoreOp>(loc, val, ind, idx);
  }
  return rewriter.create<tensor::ExtractOp>(loc, values, ivs[0]);
}

/// Generates code to stack-allocate a `memref<?xindex>` where the `?`
/// is the given `rank`. This array is intended to serve as a reusable
/// buffer for storing the indices of a single tensor element, to avoid
/// allocation in the body of loops.
static Value allocaIndices(ConversionPatternRewriter &rewriter, Location loc,
                           int64_t rank) {
  auto indexTp = rewriter.getIndexType();
  auto memTp = MemRefType::get({ShapedType::kDynamicSize}, indexTp);
  Value arg = constantIndex(rewriter, loc, rank);
  return rewriter.create<memref::AllocaOp>(loc, memTp, ValueRange{arg});
}

//===----------------------------------------------------------------------===//
// Conversion rules.
//===----------------------------------------------------------------------===//

/// Sparse conversion rule for returns.
class SparseReturnConverter : public OpConversionPattern<ReturnOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<ReturnOp>(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for dimension accesses.
class SparseTensorToDimSizeConverter
    : public OpConversionPattern<tensor::DimOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    auto enc = getSparseTensorEncoding(op.source().getType());
    if (!enc)
      return failure();
    // Permute the dim index.
    Optional<int64_t> index = op.getConstantIndex();
    if (!index.hasValue())
      return failure();
    int64_t idx = index.getValue();
    if (AffineMap p = enc.getDimOrdering())
      idx = p.getPermutedPosition(idx);
    // Generate the call.
    StringRef name = "sparseDimSize";
    SmallVector<Value, 2> params;
    params.push_back(adaptor.getOperands()[0]);
    params.push_back(constantIndex(rewriter, op.getLoc(), idx));
    rewriter.replaceOpWithNewOp<CallOp>(
        op, resType, getFunc(op, name, resType, params), params);
    return success();
  }
};
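// For example (illustrative; #SparseMatrix stands for some encoding), with a
// dimension ordering that swaps the two dimensions of a matrix, the access
//   %d = tensor.dim %t, %c1 : tensor<?x?xf64, #SparseMatrix>
// lowers to a runtime query of the permuted position:
//   %d = call @sparseDimSize(%ptr, %c0) : (!llvm.ptr<i8>, index) -> index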
/// Sparse conversion rule for the new operator.
class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NewOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    auto enc = getSparseTensorEncoding(resType);
    if (!enc)
      return failure();
    Value perm;
    rewriter.replaceOp(
        op, genNewCall(rewriter, op, enc, 0, perm, adaptor.getOperands()[0]));
    return success();
  }
};

/// Sparse conversion rule for the convert operator.
class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ConvertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    auto encDst = getSparseTensorEncoding(resType);
    auto encSrc = getSparseTensorEncoding(op.source().getType());
    auto src = adaptor.getOperands()[0];
    if (encDst && encSrc) {
      // This is a sparse => sparse conversion, which is handled as follows:
      //   t = src->toCOO();         ; src to COO in dst order
      //   dst = newSparseTensor(t)
      // Using the coordinate scheme as an intermediate does not always
      // yield the fastest conversion but avoids the need for a full
      // O(N^2) conversion matrix.
      Value perm;
      Value coo = genNewCall(rewriter, op, encDst, 3, perm, src);
      rewriter.replaceOp(op, genNewCall(rewriter, op, encDst, 1, perm, coo));
      return success();
    }
    if (!encDst || encSrc) {
      // TODO: sparse => dense
      return failure();
    }
    // This is a dense => sparse conversion or a sparse constant in COO =>
    // sparse conversion, which is handled as follows:
    //   t = newSparseCOO()
    //   ...code to fill the COO tensor t...
    //   s = newSparseTensor(t)
    //
    // To fill the COO tensor from a dense tensor:
    //   for i1 in dim1
    //    ..
    //     for ik in dimk
    //       val = a[i1,..,ik]
    //       if val != 0
    //         t->add(val, [i1,..,ik], [p1,..,pk])
    //
    // To fill the COO tensor from a sparse constant in COO format:
    //   for i in range(NNZ)
    //     val = values[i]
    //     [i1,..,ik] = indices[i]
    //     t->add(val, [i1,..,ik], [p1,..,pk])
    //
    // Note that the dense tensor traversal code is actually implemented
    // using MLIR IR to avoid having to expose too much low-level
    // memref traversal detail to the runtime support library.
    // Also note that the code below only generates the "new" ops and
    // the loop-nest per se; the entire body of the innermost loop is
    // generated by genAddEltCall().
    Location loc = op->getLoc();
    ShapedType shape = resType.cast<ShapedType>();
    Value perm;
    Value ptr = genNewCall(rewriter, op, encDst, 2, perm);
    Value ind = allocaIndices(rewriter, loc, shape.getRank());
    SmallVector<Value> lo;
    SmallVector<Value> hi;
    SmallVector<Value> st;
    Value zero = constantIndex(rewriter, loc, 0);
    Value one = constantIndex(rewriter, loc, 1);
    auto indicesValues = genSplitSparseConstant(rewriter, loc, src);
    bool isCOOConstant = indicesValues.hasValue();
    Value indices;
    Value values;
    if (isCOOConstant) {
      indices = indicesValues->first;
      values = indicesValues->second;
      lo.push_back(zero);
      hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, values, 0));
      st.push_back(one);
    } else {
      for (unsigned i = 0, rank = shape.getRank(); i < rank; i++) {
        lo.push_back(zero);
        hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, src, i));
        st.push_back(one);
      }
    }
    Type eltType = shape.getElementType();
    unsigned rank = shape.getRank();
    scf::buildLoopNest(
        rewriter, op.getLoc(), lo, hi, st, {},
        [&](OpBuilder &builder, Location loc, ValueRange ivs,
            ValueRange args) -> scf::ValueVector {
          Value val;
          if (isCOOConstant)
            val = genIndexAndValueForSparse(rewriter, loc, indices, values,
                                            ind, ivs, rank);
          else
            val = genIndexAndValueForDense(rewriter, loc, src, ind, ivs);
          genAddEltCall(rewriter, op, eltType, ptr, val, ind, perm);
          return {};
        });
    rewriter.replaceOp(op, genNewCall(rewriter, op, encDst, 1, perm, ptr));
    return success();
  }
};

/// Sparse conversion rule for the release operator.
class SparseTensorReleaseConverter : public OpConversionPattern<ReleaseOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReleaseOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    StringRef name = "delSparseTensor";
    TypeRange none;
    rewriter.create<CallOp>(op.getLoc(), none,
                            getFunc(op, name, none, adaptor.getOperands()),
                            adaptor.getOperands());
    rewriter.eraseOp(op);
    return success();
  }
};
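// For example, sparse_tensor.release %t lowers to a call of the form
// (illustrative):
//   call @delSparseTensor(%ptr) : (!llvm.ptr<i8>) -> ()
// where %ptr is the opaque pointer that the type converter substituted
// for the sparse tensor value.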
/// Sparse conversion rule for pointer accesses.
class SparseTensorToPointersConverter
    : public OpConversionPattern<ToPointersOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToPointersOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type eltType = resType.cast<ShapedType>().getElementType();
    StringRef name;
    if (eltType.isIndex())
      name = "sparsePointers"; // 64-bit, but its own name for unique signature
    else if (eltType.isInteger(64))
      name = "sparsePointers64";
    else if (eltType.isInteger(32))
      name = "sparsePointers32";
    else if (eltType.isInteger(16))
      name = "sparsePointers16";
    else if (eltType.isInteger(8))
      name = "sparsePointers8";
    else
      return failure();
    rewriter.replaceOpWithNewOp<CallOp>(op, resType,
                                        getFunc(op, name, resType,
                                                adaptor.getOperands(),
                                                /*emitCInterface=*/true),
                                        adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for index accesses.
class SparseTensorToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToIndicesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type eltType = resType.cast<ShapedType>().getElementType();
    StringRef name;
    if (eltType.isIndex())
      name = "sparseIndices"; // 64-bit, but its own name for unique signature
    else if (eltType.isInteger(64))
      name = "sparseIndices64";
    else if (eltType.isInteger(32))
      name = "sparseIndices32";
    else if (eltType.isInteger(16))
      name = "sparseIndices16";
    else if (eltType.isInteger(8))
      name = "sparseIndices8";
    else
      return failure();
    rewriter.replaceOpWithNewOp<CallOp>(op, resType,
                                        getFunc(op, name, resType,
                                                adaptor.getOperands(),
                                                /*emitCInterface=*/true),
                                        adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for value accesses.
class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type eltType = resType.cast<ShapedType>().getElementType();
    StringRef name;
    if (eltType.isF64())
      name = "sparseValuesF64";
    else if (eltType.isF32())
      name = "sparseValuesF32";
    else if (eltType.isInteger(64))
      name = "sparseValuesI64";
    else if (eltType.isInteger(32))
      name = "sparseValuesI32";
    else if (eltType.isInteger(16))
      name = "sparseValuesI16";
    else if (eltType.isInteger(8))
      name = "sparseValuesI8";
    else
      return failure();
    rewriter.replaceOpWithNewOp<CallOp>(op, resType,
                                        getFunc(op, name, resType,
                                                adaptor.getOperands(),
                                                /*emitCInterface=*/true),
                                        adaptor.getOperands());
    return success();
  }
};
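// For example (illustrative; #CSR stands for some encoding with 64-bit
// pointer overhead):
//   %p = sparse_tensor.pointers %t, %c1
//     : tensor<?x?xf64, #CSR> to memref<?xi64>
// lowers to:
//   %p = call @sparsePointers64(%ptr, %c1)
//     : (!llvm.ptr<i8>, index) -> memref<?xi64>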
/// Sparse conversion rule for tensor reconstruction.
class SparseTensorToTensorConverter : public OpConversionPattern<ToTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  // Simply fold the operator into the pointer to the sparse storage scheme.
  LogicalResult
  matchAndRewrite(ToTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Check that all arguments of the tensor reconstruction operator are
    // calls into the support library that query exactly the same opaque
    // pointer.
    Value ptr;
    for (Value operand : adaptor.getOperands()) {
      if (auto call = operand.getDefiningOp<CallOp>()) {
        Value arg = call.getOperand(0);
        if (!arg.getType().isa<LLVM::LLVMPointerType>())
          return failure();
        if (!ptr)
          ptr = arg;
        else if (arg != ptr)
          return failure();
      }
    }
    // If a single opaque pointer is found, perform the folding.
    if (!ptr)
      return failure();
    rewriter.replaceOp(op, ptr);
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with conversion rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
                                                  RewritePatternSet &patterns) {
  patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter,
               SparseTensorNewConverter, SparseTensorConvertConverter,
               SparseTensorReleaseConverter, SparseTensorToPointersConverter,
               SparseTensorToIndicesConverter, SparseTensorToValuesConverter,
               SparseTensorToTensorConverter>(typeConverter,
                                              patterns.getContext());
}
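// A minimal usage sketch from a conversion pass (hypothetical driver code;
// the actual pass, conversion target, and type converter that maps sparse
// tensor types to opaque pointers live elsewhere):
//   RewritePatternSet patterns(ctx);
//   SparseTensorTypeConverter converter;
//   ConversionTarget target(*ctx);
//   populateSparseTensorConversionPatterns(converter, patterns);
//   if (failed(applyPartialConversion(getOperation(), target,
//                                     std::move(patterns))))
//     signalPassFailure();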