//===- SparseTensorConversion.cpp - Sparse tensor primitives conversion ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Convert sparse tensor primitives to calls into a runtime support library.
// Note that this is a current implementation choice to keep the conversion
// simple. In principle, these primitives could also be converted to actual
// elaborate IR code that implements the primitives on the selected sparse
// tensor storage schemes.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

/// Returns internal type encoding for primary storage. Keep these
/// values consistent with the sparse runtime support library.
static unsigned getPrimaryTypeEncoding(Type tp) {
  if (tp.isF64())
    return 1;
  if (tp.isF32())
    return 2;
  if (tp.isInteger(64))
    return 3;
  if (tp.isInteger(32))
    return 4;
  if (tp.isInteger(16))
    return 5;
  if (tp.isInteger(8))
    return 6;
  return 0;
}

/// Returns internal type encoding for overhead storage. Keep these
/// values consistent with the sparse runtime support library.
static unsigned getOverheadTypeEncoding(unsigned width) {
  switch (width) {
  default:
    return 1;
  case 32:
    return 2;
  case 16:
    return 3;
  case 8:
    return 4;
  }
}

/// Returns internal dimension level type encoding. Keep these
/// values consistent with the sparse runtime support library.
static unsigned
getDimLevelTypeEncoding(SparseTensorEncodingAttr::DimLevelType dlt) {
  switch (dlt) {
  case SparseTensorEncodingAttr::DimLevelType::Dense:
    return 0;
  case SparseTensorEncodingAttr::DimLevelType::Compressed:
    return 1;
  case SparseTensorEncodingAttr::DimLevelType::Singleton:
    return 2;
  }
  llvm_unreachable("Unknown SparseTensorEncodingAttr::DimLevelType");
}
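
// As a concrete example of the three encodings above (a sketch, with the
// encoding attribute abbreviated): an f64 matrix annotated with
//
//   #CSR = #sparse_tensor.encoding<{
//     dimLevelType = [ "dense", "compressed" ]
//   }>
//
// is described to the runtime by the dimension level bytes [0, 1], the
// default 64-bit overhead encodings (1, 1), and primary type encoding 1.
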
/// Returns a constant tensor of integers of the given width holding the
/// given values. We cast the static shape into a dynamic shape to ensure
/// that the method signature remains uniform across different tensor
/// dimensions.
static Value getTensor(ConversionPatternRewriter &rewriter, unsigned width,
                       Location loc, ArrayRef<APInt> values) {
  Type etp = rewriter.getIntegerType(width);
  unsigned sz = values.size();
  RankedTensorType tt1 = RankedTensorType::get({sz}, etp);
  RankedTensorType tt2 = RankedTensorType::get({ShapedType::kDynamicSize}, etp);
  auto elts = rewriter.create<arith::ConstantOp>(
      loc, DenseElementsAttr::get(tt1, values));
  return rewriter.create<tensor::CastOp>(loc, tt2, elts);
}

/// Returns a function reference (first hit also inserts into module). Sets
/// the "llvm.emit_c_interface" attribute on the function declaration when
/// requested, so that LLVM lowering generates a wrapper function that takes
/// care of ABI complications with passing in and returning MemRefs to C
/// functions.
static FlatSymbolRefAttr getFunc(Operation *op, StringRef name,
                                 TypeRange resultType, ValueRange operands,
                                 bool emitCInterface = false) {
  MLIRContext *context = op->getContext();
  auto module = op->getParentOfType<ModuleOp>();
  auto result = SymbolRefAttr::get(context, name);
  auto func = module.lookupSymbol<FuncOp>(result.getAttr());
  if (!func) {
    OpBuilder moduleBuilder(module.getBodyRegion());
    func = moduleBuilder.create<FuncOp>(
        op->getLoc(), name,
        FunctionType::get(context, operands.getTypes(), resultType));
    func.setPrivate();
    if (emitCInterface)
      func->setAttr("llvm.emit_c_interface", UnitAttr::get(context));
  }
  return result;
}
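
// For illustration (a sketch, not verbatim output), a call such as
// getTensor(rewriter, 8, loc, {APInt(8, 0), APInt(8, 1)}) yields IR along
// the lines of:
//
//   %0 = arith.constant dense<[0, 1]> : tensor<2xi8>
//   %1 = tensor.cast %0 : tensor<2xi8> to tensor<?xi8>
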
/// Generates a call into the "swiss army knife" method of the sparse runtime
/// support library for materializing sparse tensors into the computation. The
/// method returns the call value and assigns the permutation to 'perm'.
static Value genNewCall(ConversionPatternRewriter &rewriter, Operation *op,
                        SparseTensorEncodingAttr &enc, uint32_t action,
                        Value &perm, Value ptr = Value()) {
  Location loc = op->getLoc();
  ShapedType resType = op->getResult(0).getType().cast<ShapedType>();
  SmallVector<Value, 8> params;
  // Sparsity annotations in tensor constant form.
  SmallVector<APInt, 4> attrs;
  unsigned sz = enc.getDimLevelType().size();
  for (unsigned i = 0; i < sz; i++)
    attrs.push_back(
        APInt(8, getDimLevelTypeEncoding(enc.getDimLevelType()[i])));
  params.push_back(getTensor(rewriter, 8, loc, attrs));
  // Dimension sizes array of the enveloping *dense* tensor. Useful for either
  // verification of external data, or for construction of internal data.
  auto shape = resType.getShape();
  SmallVector<APInt, 4> sizes;
  for (unsigned i = 0; i < sz; i++) {
    uint64_t s = shape[i] == ShapedType::kDynamicSize ? 0 : shape[i];
    sizes.push_back(APInt(64, s));
  }
  params.push_back(getTensor(rewriter, 64, loc, sizes));
  // Dimension order permutation array. This is the "identity" permutation by
  // default, or otherwise the "reverse" permutation of a given ordering, so
  // that indices can be mapped quickly to the right position.
  SmallVector<APInt, 4> rev(sz);
  if (AffineMap p = enc.getDimOrdering()) {
    for (unsigned i = 0; i < sz; i++)
      rev[p.getDimPosition(i)] = APInt(64, i);
  } else {
    for (unsigned i = 0; i < sz; i++)
      rev[i] = APInt(64, i);
  }
  perm = getTensor(rewriter, 64, loc, rev);
  params.push_back(perm);
  // Secondary and primary types encoding.
  unsigned secPtr = getOverheadTypeEncoding(enc.getPointerBitWidth());
  unsigned secInd = getOverheadTypeEncoding(enc.getIndexBitWidth());
  unsigned primary = getPrimaryTypeEncoding(resType.getElementType());
  assert(primary);
  params.push_back(rewriter.create<arith::ConstantOp>(
      loc, rewriter.getI64IntegerAttr(secPtr)));
  params.push_back(rewriter.create<arith::ConstantOp>(
      loc, rewriter.getI64IntegerAttr(secInd)));
  params.push_back(rewriter.create<arith::ConstantOp>(
      loc, rewriter.getI64IntegerAttr(primary)));
  // User action and pointer.
  Type pTp = LLVM::LLVMPointerType::get(IntegerType::get(op->getContext(), 8));
  if (!ptr)
    ptr = rewriter.create<LLVM::NullOp>(loc, pTp);
  params.push_back(rewriter.create<arith::ConstantOp>(
      loc, rewriter.getI32IntegerAttr(action)));
  params.push_back(ptr);
  // Generate the call to create new tensor.
  StringRef name = "newSparseTensor";
  auto call = rewriter.create<CallOp>(
      loc, pTp, getFunc(op, name, pTp, params, /*emitCInterface=*/true),
      params);
  return call.getResult(0);
}
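
// A sketch of the IR generated by genNewCall for a 2-d tensor (names and
// exact operand types abbreviated, not verbatim output):
//
//   %attrs = ... : tensor<?xi8>   // dimension level types
//   %sizes = ... : tensor<?xi64>  // dense dimension sizes
//   %perm  = ... : tensor<?xi64>  // dimension order permutation
//   %t = call @newSparseTensor(%attrs, %sizes, %perm, %secPtr, %secInd,
//                              %primary, %action, %ptr)
//       : (...) -> !llvm.ptr<i8>
//
// The callers below pass 'action' values 0 (materialize from the opaque
// pointer, e.g., an external source), 1 (construct storage from a COO
// scheme), 2 (create an empty COO scheme to fill), and 3 (extract a COO
// scheme from existing storage); keep these consistent with the runtime
// support library.
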
/// Generates a constant zero of the given type.
static Value getZero(ConversionPatternRewriter &rewriter, Location loc,
                     Type t) {
  return rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(t));
}

/// Generates the comparison `v != 0` where `v` is of numeric type `t`.
/// For floating-point types, we use the "unordered" comparison (i.e., the
/// result is also true if `v` is NaN).
static Value genIsNonzero(ConversionPatternRewriter &rewriter, Location loc,
                          Value v) {
  Type t = v.getType();
  Value zero = getZero(rewriter, loc, t);
  if (t.isa<FloatType>())
    return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE, v,
                                          zero);
  if (t.isIntOrIndex())
    return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, v,
                                          zero);
  llvm_unreachable("Unknown element type");
}

/// Generates code to read the value from tensor[ivs] and, if it is nonzero,
/// to store the indices ivs into the memref ind. On return, the insertion
/// point is inside the if-then branch, right after the stores into ind, so
/// that the addEltX call generated afterwards lands inside that branch as
/// well. The generated code looks like:
///   if (tensor[ivs] != 0) {
///     ind = ivs
static Value genIndexAndValueForDense(ConversionPatternRewriter &rewriter,
                                      Operation *op, Value tensor, Value ind,
                                      ValueRange ivs) {
  Location loc = op->getLoc();
  Value val = rewriter.create<tensor::ExtractOp>(loc, tensor, ivs);
  Value cond = genIsNonzero(rewriter, loc, val);
  scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ false);
  rewriter.setInsertionPointToStart(&ifOp.thenRegion().front());
  unsigned i = 0;
  for (auto iv : ivs) {
    Value idx =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(i++));
    rewriter.create<memref::StoreOp>(loc, iv, ind, idx);
  }
  return val;
}

/// Generates a call that adds one element to a coordinate scheme.
/// In particular, this generates code like the following:
///   val = a[i1,..,ik];
///   if val != 0
///     t->add(val, [i1,..,ik], [p1,..,pk]);
static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
                          Type eltType, Value ptr, Value val, Value ind,
                          Value perm) {
  Location loc = op->getLoc();
  StringRef name;
  if (eltType.isF64())
    name = "addEltF64";
  else if (eltType.isF32())
    name = "addEltF32";
  else if (eltType.isInteger(64))
    name = "addEltI64";
  else if (eltType.isInteger(32))
    name = "addEltI32";
  else if (eltType.isInteger(16))
    name = "addEltI16";
  else if (eltType.isInteger(8))
    name = "addEltI8";
  else
    llvm_unreachable("Unknown element type");
  SmallVector<Value, 8> params;
  params.push_back(ptr);
  params.push_back(val);
  params.push_back(ind);
  params.push_back(perm);
  Type pTp = LLVM::LLVMPointerType::get(IntegerType::get(op->getContext(), 8));
  rewriter.create<CallOp>(
      loc, pTp, getFunc(op, name, pTp, params, /*emitCInterface=*/true),
      params);
}

/// If the tensor is a sparse constant, generates and returns the pair of
/// constants for its indices and values.
static Optional<std::pair<Value, Value>>
genSplitSparseConstant(ConversionPatternRewriter &rewriter, ConvertOp op,
                       Value tensor) {
  if (auto constOp = tensor.getDefiningOp<arith::ConstantOp>()) {
    if (auto attr = constOp.value().dyn_cast<SparseElementsAttr>()) {
      Location loc = op->getLoc();
      DenseElementsAttr indicesAttr = attr.getIndices();
      Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr);
      DenseElementsAttr valuesAttr = attr.getValues();
      Value values = rewriter.create<arith::ConstantOp>(loc, valuesAttr);
      return std::make_pair(indices, values);
    }
  }
  return {};
}

/// Generates code to copy the index at indices[ivs] into ind and returns
/// the value at values[ivs].
static Value genIndexAndValueForSparse(ConversionPatternRewriter &rewriter,
                                       Operation *op, Value indices,
                                       Value values, Value ind, ValueRange ivs,
                                       unsigned rank) {
  Location loc = op->getLoc();
  for (unsigned i = 0; i < rank; i++) {
    Value idx =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(i));
    Value val = rewriter.create<tensor::ExtractOp>(loc, indices,
                                                   ValueRange{ivs[0], idx});
    val =
        rewriter.create<arith::IndexCastOp>(loc, val, rewriter.getIndexType());
    rewriter.create<memref::StoreOp>(loc, val, ind, idx);
  }
  return rewriter.create<tensor::ExtractOp>(loc, values, ivs[0]);
}

/// Generates code to stack-allocate a `memref<?xindex>` where the `?`
/// is the given `rank`. This array is intended to serve as a reusable
/// buffer for storing the indices of a single tensor element, to avoid
/// allocation in the body of loops.
static Value allocaIndices(ConversionPatternRewriter &rewriter, Location loc,
                           int64_t rank) {
  auto indexTp = rewriter.getIndexType();
  auto memTp = MemRefType::get({ShapedType::kDynamicSize}, indexTp);
  Value arg =
      rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(rank));
  return rewriter.create<memref::AllocaOp>(loc, memTp, ValueRange{arg});
}
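
// For example, allocaIndices(rewriter, loc, 2) produces, roughly:
//
//   %c2 = arith.constant 2 : index
//   %0 = memref.alloca(%c2) : memref<?xindex>
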
//===----------------------------------------------------------------------===//
// Conversion rules.
//===----------------------------------------------------------------------===//

/// Sparse conversion rule for returns.
class SparseReturnConverter : public OpConversionPattern<ReturnOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<ReturnOp>(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for dimension accesses.
class SparseTensorToDimSizeConverter
    : public OpConversionPattern<tensor::DimOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    auto enc = getSparseTensorEncoding(op.source().getType());
    if (!enc)
      return failure();
    // Permute the dim index.
    Optional<int64_t> index = op.getConstantIndex();
    if (!index.hasValue())
      return failure();
    int64_t idx = index.getValue();
    if (AffineMap p = enc.getDimOrdering())
      idx = p.getPermutedPosition(idx);
    // Generate the call.
    StringRef name = "sparseDimSize";
    SmallVector<Value, 2> params;
    params.push_back(adaptor.getOperands()[0]);
    params.push_back(rewriter.create<arith::ConstantOp>(
        op.getLoc(), rewriter.getIndexAttr(idx)));
    rewriter.replaceOpWithNewOp<CallOp>(
        op, resType, getFunc(op, name, resType, params), params);
    return success();
  }
};

/// Sparse conversion rule for the new operator.
class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NewOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    auto enc = getSparseTensorEncoding(resType);
    if (!enc)
      return failure();
    Value perm;
    rewriter.replaceOp(
        op, genNewCall(rewriter, op, enc, 0, perm, adaptor.getOperands()[0]));
    return success();
  }
};
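
// As an illustration (a sketch, with the encoding abbreviated), the rule
// above rewrites
//
//   %t = sparse_tensor.new %src : !llvm.ptr<i8> to tensor<?x?xf64, #CSR>
//
// into a @newSparseTensor call with action 0 and %src as the opaque pointer
// argument, yielding a fresh !llvm.ptr<i8> handle to the sparse storage.
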
/// Sparse conversion rule for the convert operator.
class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ConvertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    auto encDst = getSparseTensorEncoding(resType);
    auto encSrc = getSparseTensorEncoding(op.source().getType());
    auto src = adaptor.getOperands()[0];
    if (encDst && encSrc) {
      // This is a sparse => sparse conversion, which is handled as follows:
      //   t = src->toCOO();      // src to COO in dst order
      //   dst = newSparseTensor(t)
      // Using the coordinate scheme as an intermediate does not always
      // yield the fastest conversion but avoids the need for a full
      // O(N^2) conversion matrix.
      Value perm;
      Value coo = genNewCall(rewriter, op, encDst, 3, perm, src);
      rewriter.replaceOp(op, genNewCall(rewriter, op, encDst, 1, perm, coo));
      return success();
    }
    if (!encDst || encSrc) {
      // TODO: sparse => dense
      return failure();
    }
    // This is a dense => sparse conversion or a sparse constant in COO =>
    // sparse conversion, which is handled as follows:
    //   t = newSparseCOO()
    //   ...code to fill the COO tensor t...
    //   s = newSparseTensor(t)
    //
    // To fill the COO tensor from a dense tensor:
    //   for i1 in dim1
    //     ..
    //       for ik in dimk
    //         val = a[i1,..,ik]
    //         if val != 0
    //           t->add(val, [i1,..,ik], [p1,..,pk])
    //
    // To fill the COO tensor from a sparse constant in COO format:
    //   for i in range(NNZ)
    //     val = values[i]
    //     [i1,..,ik] = indices[i]
    //     t->add(val, [i1,..,ik], [p1,..,pk])
    //
    // Note that the dense tensor traversal code is actually implemented
    // using MLIR IR to avoid having to expose too much low-level
    // memref traversal detail to the runtime support library.
    // Also note that the code below only generates the "new" ops and
    // the loop-nest per se; the entire body of the innermost loop is
    // generated by genIndexAndValueFor{Dense,Sparse}() and genAddEltCall().
    Location loc = op->getLoc();
    ShapedType shape = resType.cast<ShapedType>();
    Value perm;
    Value ptr = genNewCall(rewriter, op, encDst, 2, perm);
    Value ind = allocaIndices(rewriter, loc, shape.getRank());
    SmallVector<Value> lo;
    SmallVector<Value> hi;
    SmallVector<Value> st;
    Value zero =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
    Value one =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
    auto indicesValues = genSplitSparseConstant(rewriter, op, src);
    bool isCOOConstant = indicesValues.hasValue();
    Value indices;
    Value values;
    if (isCOOConstant) {
      indices = indicesValues->first;
      values = indicesValues->second;
      lo.push_back(zero);
      hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, values, 0));
      st.push_back(one);
    } else {
      for (unsigned i = 0, rank = shape.getRank(); i < rank; i++) {
        lo.push_back(zero);
        hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, src, i));
        st.push_back(one);
      }
    }
    Type eltType = shape.getElementType();
    unsigned rank = shape.getRank();
    scf::buildLoopNest(
        rewriter, op.getLoc(), lo, hi, st, {},
        [&](OpBuilder &builder, Location loc, ValueRange ivs,
            ValueRange args) -> scf::ValueVector {
          Value val;
          if (isCOOConstant)
            val = genIndexAndValueForSparse(rewriter, op, indices, values, ind,
                                            ivs, rank);
          else
            val = genIndexAndValueForDense(rewriter, op, src, ind, ivs);
          genAddEltCall(rewriter, op, eltType, ptr, val, ind, perm);
          return {};
        });
    rewriter.replaceOp(op, genNewCall(rewriter, op, encDst, 1, perm, ptr));
    return success();
  }
};

/// Sparse conversion rule for the release operator.
class SparseTensorReleaseConverter : public OpConversionPattern<ReleaseOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReleaseOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    StringRef name = "delSparseTensor";
    TypeRange none;
    rewriter.create<CallOp>(op.getLoc(), none,
                            getFunc(op, name, none, adaptor.getOperands()),
                            adaptor.getOperands());
    rewriter.eraseOp(op);
    return success();
  }
};
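
// For example (a sketch), the release rule rewrites
//
//   sparse_tensor.release %t : tensor<?x?xf64, #CSR>
//
// into a plain runtime call that reclaims the underlying storage:
//
//   call @delSparseTensor(%t) : (!llvm.ptr<i8>) -> ()
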
/// Sparse conversion rule for pointer accesses.
class SparseTensorToPointersConverter
    : public OpConversionPattern<ToPointersOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToPointersOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type eltType = resType.cast<ShapedType>().getElementType();
    StringRef name;
    if (eltType.isIndex())
      name = "sparsePointers"; // 64-bit, but its own name for unique signature
    else if (eltType.isInteger(64))
      name = "sparsePointers64";
    else if (eltType.isInteger(32))
      name = "sparsePointers32";
    else if (eltType.isInteger(16))
      name = "sparsePointers16";
    else if (eltType.isInteger(8))
      name = "sparsePointers8";
    else
      return failure();
    rewriter.replaceOpWithNewOp<CallOp>(op, resType,
                                        getFunc(op, name, resType,
                                                adaptor.getOperands(),
                                                /*emitCInterface=*/true),
                                        adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for index accesses.
class SparseTensorToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToIndicesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type eltType = resType.cast<ShapedType>().getElementType();
    StringRef name;
    if (eltType.isIndex())
      name = "sparseIndices"; // 64-bit, but its own name for unique signature
    else if (eltType.isInteger(64))
      name = "sparseIndices64";
    else if (eltType.isInteger(32))
      name = "sparseIndices32";
    else if (eltType.isInteger(16))
      name = "sparseIndices16";
    else if (eltType.isInteger(8))
      name = "sparseIndices8";
    else
      return failure();
    rewriter.replaceOpWithNewOp<CallOp>(op, resType,
                                        getFunc(op, name, resType,
                                                adaptor.getOperands(),
                                                /*emitCInterface=*/true),
                                        adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for value accesses.
class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type eltType = resType.cast<ShapedType>().getElementType();
    StringRef name;
    if (eltType.isF64())
      name = "sparseValuesF64";
    else if (eltType.isF32())
      name = "sparseValuesF32";
    else if (eltType.isInteger(64))
      name = "sparseValuesI64";
    else if (eltType.isInteger(32))
      name = "sparseValuesI32";
    else if (eltType.isInteger(16))
      name = "sparseValuesI16";
    else if (eltType.isInteger(8))
      name = "sparseValuesI8";
    else
      return failure();
    rewriter.replaceOpWithNewOp<CallOp>(op, resType,
                                        getFunc(op, name, resType,
                                                adaptor.getOperands(),
                                                /*emitCInterface=*/true),
                                        adaptor.getOperands());
    return success();
  }
};
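
// As an illustration (a sketch, with the encoding abbreviated), the value
// access rule rewrites
//
//   %v = sparse_tensor.values %t : tensor<?x?xf64, #CSR> to memref<?xf64>
//
// into
//
//   %v = call @sparseValuesF64(%t) : (!llvm.ptr<i8>) -> memref<?xf64>
//
// The pointer and index rules follow the same pattern with their
// width-specific runtime entry points.
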
/// Sparse conversion rule for tensor reconstruction.
class SparseTensorToTensorConverter : public OpConversionPattern<ToTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  // Simply fold the operator into the pointer to the sparse storage scheme.
  LogicalResult
  matchAndRewrite(ToTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Check that all arguments of the tensor reconstruction operator are
    // calls into the support library that query exactly the same opaque
    // pointer.
    Value ptr;
    for (Value operand : adaptor.getOperands()) {
      if (auto call = operand.getDefiningOp<CallOp>()) {
        Value arg = call.getOperand(0);
        if (!arg.getType().isa<LLVM::LLVMPointerType>())
          return failure();
        if (!ptr)
          ptr = arg;
        else if (arg != ptr)
          return failure();
      }
    }
    // If a single opaque pointer is found, perform the folding.
    if (!ptr)
      return failure();
    rewriter.replaceOp(op, ptr);
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with conversion rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
                                                  RewritePatternSet &patterns) {
  patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter,
               SparseTensorNewConverter, SparseTensorConvertConverter,
               SparseTensorReleaseConverter, SparseTensorToPointersConverter,
               SparseTensorToIndicesConverter, SparseTensorToValuesConverter,
               SparseTensorToTensorConverter>(typeConverter,
                                              patterns.getContext());
}
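
// A minimal sketch of how a conversion pass might wire these patterns up;
// the type converter and conversion target setup shown here are assumptions
// for illustration, not part of this file:
//
//   // Inside some pass's runOnOperation():
//   MLIRContext *ctx = &getContext();
//   TypeConverter converter; // assumed to map sparse tensors to !llvm.ptr<i8>
//   ConversionTarget target(*ctx);
//   RewritePatternSet patterns(ctx);
//   populateSparseTensorConversionPatterns(converter, patterns);
//   if (failed(applyPartialConversion(getOperation(), target,
//                                     std::move(patterns))))
//     signalPassFailure();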