//===- SparseTensorConversion.cpp - Sparse tensor primitives conversion ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Convert sparse tensor primitives to calls into a runtime support library.
// Note that this is a current implementation choice to keep the conversion
// simple. In principle, these primitives could also be converted to actual
// elaborate IR code that implements the primitives on the selected sparse
// tensor storage schemes.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

/// New tensor storage action. Keep these values consistent with
/// the sparse runtime support library.
enum Action : uint32_t {
  kEmpty = 0,
  kFromFile = 1,
  kFromCOO = 2,
  kEmptyCOO = 3,
  kToCOO = 4
};

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

/// Returns internal type encoding for primary storage. Keep these
/// values consistent with the sparse runtime support library.
static unsigned getPrimaryTypeEncoding(Type tp) {
  if (tp.isF64())
    return 1;
  if (tp.isF32())
    return 2;
  if (tp.isInteger(64))
    return 3;
  if (tp.isInteger(32))
    return 4;
  if (tp.isInteger(16))
    return 5;
  if (tp.isInteger(8))
    return 6;
  return 0;
}

/// Returns internal type encoding for overhead storage. Keep these
/// values consistent with the sparse runtime support library.
static unsigned getOverheadTypeEncoding(unsigned width) {
  switch (width) {
  default:
    return 1;
  case 32:
    return 2;
  case 16:
    return 3;
  case 8:
    return 4;
  }
}

/// Returns internal dimension level type encoding. Keep these
/// values consistent with the sparse runtime support library.
static unsigned
getDimLevelTypeEncoding(SparseTensorEncodingAttr::DimLevelType dlt) {
  switch (dlt) {
  case SparseTensorEncodingAttr::DimLevelType::Dense:
    return 0;
  case SparseTensorEncodingAttr::DimLevelType::Compressed:
    return 1;
  case SparseTensorEncodingAttr::DimLevelType::Singleton:
    return 2;
  }
  llvm_unreachable("Unknown SparseTensorEncodingAttr::DimLevelType");
}

/// Generates a constant zero of the given type.
inline static Value constantZero(ConversionPatternRewriter &rewriter,
                                 Location loc, Type t) {
  return rewriter.create<arith::ConstantOp>(loc, t, rewriter.getZeroAttr(t));
}

/// Generates a constant of `index` type.
inline static Value constantIndex(ConversionPatternRewriter &rewriter,
                                  Location loc, unsigned i) {
  return rewriter.create<arith::ConstantIndexOp>(loc, i);
}

/// Generates a constant of `i64` type.
inline static Value constantI64(ConversionPatternRewriter &rewriter,
                                Location loc, int64_t i) {
  return rewriter.create<arith::ConstantIntOp>(loc, i, 64);
}

/// Generates a constant of `i32` type.
inline static Value constantI32(ConversionPatternRewriter &rewriter,
                                Location loc, int32_t i) {
  return rewriter.create<arith::ConstantIntOp>(loc, i, 32);
}

/// Generates a constant of `i8` type.
inline static Value constantI8(ConversionPatternRewriter &rewriter,
                               Location loc, int8_t i) {
  return rewriter.create<arith::ConstantIntOp>(loc, i, 8);
}

/// Returns a function reference (first hit also inserts into module). Sets
/// the "llvm.emit_c_interface" attribute on the function declaration when
/// requested, so that LLVM lowering generates a wrapper function that takes
/// care of ABI complications with passing in and returning MemRefs to C
/// functions.
static FlatSymbolRefAttr getFunc(Operation *op, StringRef name,
                                 TypeRange resultType, ValueRange operands,
                                 bool emitCInterface = false) {
  MLIRContext *context = op->getContext();
  auto module = op->getParentOfType<ModuleOp>();
  auto result = SymbolRefAttr::get(context, name);
  auto func = module.lookupSymbol<FuncOp>(result.getAttr());
  if (!func) {
    OpBuilder moduleBuilder(module.getBodyRegion());
    func = moduleBuilder.create<FuncOp>(
        op->getLoc(), name,
        FunctionType::get(context, operands.getTypes(), resultType));
    func.setPrivate();
    if (emitCInterface)
      func->setAttr("llvm.emit_c_interface", UnitAttr::get(context));
  }
  return result;
}
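
// For illustration, a typical use from the patterns below looks as follows
// (a sketch; `params` and `resType` stand for whatever operands and result
// type the op at hand carries):
//
//   rewriter.replaceOpWithNewOp<CallOp>(
//       op, resType, getFunc(op, "sparseDimSize", resType, params), params);
//
// The first call for a given name declares the private function in the
// module; subsequent calls reuse the existing declaration.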

/// Generates a temporary buffer of the given size and type.
static Value genAlloca(ConversionPatternRewriter &rewriter, Location loc,
                       unsigned sz, Type tp) {
  auto memTp = MemRefType::get({ShapedType::kDynamicSize}, tp);
  Value a = constantIndex(rewriter, loc, sz);
  return rewriter.create<memref::AllocaOp>(loc, memTp, ValueRange{a});
}

/// Fills a temporary buffer of the given type with arguments.
static Value genBuffer(ConversionPatternRewriter &rewriter, Location loc,
                       ArrayRef<Value> values) {
  unsigned sz = values.size();
  assert(sz >= 1);
  Value buffer = genAlloca(rewriter, loc, sz, values[0].getType());
  for (unsigned i = 0; i < sz; i++) {
    Value idx = constantIndex(rewriter, loc, i);
    rewriter.create<memref::StoreOp>(loc, values[i], buffer, idx);
  }
  return buffer;
}

/// Generates a call into the "swiss army knife" method of the sparse runtime
/// support library for materializing sparse tensors into the computation. The
/// method returns the call value and assigns the permutation to 'perm'.
static Value genNewCall(ConversionPatternRewriter &rewriter, Operation *op,
                        SparseTensorEncodingAttr &enc, uint32_t action,
                        Value &perm, ValueRange szs, Value ptr = Value()) {
  Location loc = op->getLoc();
  ShapedType resType = op->getResult(0).getType().cast<ShapedType>();
  SmallVector<Value, 8> params;
  // Sparsity annotations in tensor constant form.
  SmallVector<Value, 4> attrs;
  ArrayRef<SparseTensorEncodingAttr::DimLevelType> dlt = enc.getDimLevelType();
  unsigned sz = dlt.size();
  for (unsigned i = 0; i < sz; i++)
    attrs.push_back(constantI8(rewriter, loc, getDimLevelTypeEncoding(dlt[i])));
  params.push_back(genBuffer(rewriter, loc, attrs));
  // Dimension sizes array of the enveloping *dense* tensor. Useful for either
  // verification of external data, or for construction of internal data.
  auto shape = resType.getShape();
  SmallVector<Value, 4> sizes;
  if (szs.size() > 0) {
    for (Value s : szs)
      sizes.push_back(
          rewriter.create<arith::IndexCastOp>(loc, s, rewriter.getI64Type()));
  } else {
    for (unsigned i = 0; i < sz; i++) {
      uint64_t s = shape[i] == ShapedType::kDynamicSize ? 0 : shape[i];
      sizes.push_back(constantI64(rewriter, loc, s));
    }
  }
  params.push_back(genBuffer(rewriter, loc, sizes));
  // Dimension order permutation array. This is the "identity" permutation by
  // default, or otherwise the inverse permutation of a given ordering, so
  // that indices can be mapped quickly to the right position.
  SmallVector<Value, 4> rev(sz);
  if (AffineMap p = enc.getDimOrdering()) {
    for (unsigned i = 0; i < sz; i++)
      rev[p.getDimPosition(i)] = constantI64(rewriter, loc, i);
  } else {
    for (unsigned i = 0; i < sz; i++)
      rev[i] = constantI64(rewriter, loc, i);
  }
  perm = genBuffer(rewriter, loc, rev);
  params.push_back(perm);
  // Secondary and primary types encoding.
  unsigned secPtr = getOverheadTypeEncoding(enc.getPointerBitWidth());
  unsigned secInd = getOverheadTypeEncoding(enc.getIndexBitWidth());
  unsigned primary = getPrimaryTypeEncoding(resType.getElementType());
  assert(primary);
  params.push_back(constantI64(rewriter, loc, secPtr));
  params.push_back(constantI64(rewriter, loc, secInd));
  params.push_back(constantI64(rewriter, loc, primary));
  // User action and pointer.
  Type pTp = LLVM::LLVMPointerType::get(rewriter.getI8Type());
  if (!ptr)
    ptr = rewriter.create<LLVM::NullOp>(loc, pTp);
  params.push_back(constantI32(rewriter, loc, action));
  params.push_back(ptr);
  // Generate the call to create new tensor.
  StringRef name = "newSparseTensor";
  auto call = rewriter.create<CallOp>(
      loc, pTp, getFunc(op, name, pTp, params, /*emitCInterface=*/true),
      params);
  return call.getResult(0);
}
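
// For illustration, for a 2-d f64 tensor with default overhead bit widths
// and action kEmptyCOO, the generated IR resembles the following sketch
// (the stores that initialize the buffers are omitted):
//
//   %attrs = memref.alloca(%c2) : memref<?xi8>   ; dim level types
//   %sizes = memref.alloca(%c2) : memref<?xi64>  ; dense dimension sizes
//   %perm  = memref.alloca(%c2) : memref<?xi64>  ; inverse dim ordering
//   %null  = llvm.mlir.null : !llvm.ptr<i8>
//   %ptr = call @newSparseTensor(%attrs, %sizes, %perm, %c1_i64, %c1_i64,
//                                %c1_i64, %c3_i32, %null)
//       : (memref<?xi8>, memref<?xi64>, memref<?xi64>, i64, i64, i64, i32,
//          !llvm.ptr<i8>) -> !llvm.ptr<i8>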

/// Generates the comparison `v != 0` where `v` is of numeric type `t`.
/// For floating types, we use the "unordered" comparator (i.e., returns
/// true if `v` is NaN).
static Value genIsNonzero(ConversionPatternRewriter &rewriter, Location loc,
                          Value v) {
  Type t = v.getType();
  Value zero = constantZero(rewriter, loc, t);
  if (t.isa<FloatType>())
    return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE, v,
                                          zero);
  if (t.isIntOrIndex())
    return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, v,
                                          zero);
  llvm_unreachable("Unknown element type");
}

/// Generates the code to read the value from tensor[ivs] and to conditionally
/// store the indices ivs to the memory in ind. The generated code looks like
/// the following, and the insertion point after this routine is inside the
/// if-then branch behind the assignment to ind, so that the addEltX call
/// generated afterwards ends up inside the if-then branch as well.
///    if (tensor[ivs] != 0)
///      ind = ivs
static Value genIndexAndValueForDense(ConversionPatternRewriter &rewriter,
                                      Location loc, Value tensor, Value ind,
                                      ValueRange ivs) {
  Value val = rewriter.create<tensor::ExtractOp>(loc, tensor, ivs);
  Value cond = genIsNonzero(rewriter, loc, val);
  scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ false);
  rewriter.setInsertionPointToStart(&ifOp.thenRegion().front());
  unsigned i = 0;
  for (auto iv : ivs) {
    Value idx = constantIndex(rewriter, loc, i++);
    rewriter.create<memref::StoreOp>(loc, iv, ind, idx);
  }
  return val;
}

/// Generates a call that adds one element to a coordinate scheme.
/// In particular, this generates code like the following:
///   val = a[i1,..,ik];
///   if val != 0
///     t->add(val, [i1,..,ik], [p1,..,pk]);
static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
                          Type eltType, Value ptr, Value val, Value ind,
                          Value perm) {
  Location loc = op->getLoc();
  StringRef name;
  if (eltType.isF64())
    name = "addEltF64";
  else if (eltType.isF32())
    name = "addEltF32";
  else if (eltType.isInteger(64))
    name = "addEltI64";
  else if (eltType.isInteger(32))
    name = "addEltI32";
  else if (eltType.isInteger(16))
    name = "addEltI16";
  else if (eltType.isInteger(8))
    name = "addEltI8";
  else
    llvm_unreachable("Unknown element type");
  SmallVector<Value, 8> params;
  params.push_back(ptr);
  params.push_back(val);
  params.push_back(ind);
  params.push_back(perm);
  Type pTp = LLVM::LLVMPointerType::get(rewriter.getI8Type());
  rewriter.create<CallOp>(
      loc, pTp, getFunc(op, name, pTp, params, /*emitCInterface=*/true),
      params);
}
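
// For illustration, for an f64 element the generated call resembles the
// following sketch:
//
//   %t = call @addEltF64(%ptr, %val, %ind, %perm)
//       : (!llvm.ptr<i8>, f64, memref<?xindex>, memref<?xi64>)
//       -> !llvm.ptr<i8>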

/// If the tensor is a sparse constant, generates and returns the pair of
/// constants for the indices and the values.
static Optional<std::pair<Value, Value>>
genSplitSparseConstant(ConversionPatternRewriter &rewriter, Location loc,
                       Value tensor) {
  if (auto constOp = tensor.getDefiningOp<arith::ConstantOp>()) {
    if (auto attr = constOp.value().dyn_cast<SparseElementsAttr>()) {
      DenseElementsAttr indicesAttr = attr.getIndices();
      Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr);
      DenseElementsAttr valuesAttr = attr.getValues();
      Value values = rewriter.create<arith::ConstantOp>(loc, valuesAttr);
      return std::make_pair(indices, values);
    }
  }
  return {};
}

/// Generates the code to copy the index at indices[ivs] to ind, and returns
/// the value at values[ivs].
static Value genIndexAndValueForSparse(ConversionPatternRewriter &rewriter,
                                       Location loc, Value indices,
                                       Value values, Value ind, ValueRange ivs,
                                       unsigned rank) {
  for (unsigned i = 0; i < rank; i++) {
    Value idx = constantIndex(rewriter, loc, i);
    Value val = rewriter.create<tensor::ExtractOp>(loc, indices,
                                                   ValueRange{ivs[0], idx});
    val =
        rewriter.create<arith::IndexCastOp>(loc, val, rewriter.getIndexType());
    rewriter.create<memref::StoreOp>(loc, val, ind, idx);
  }
  return rewriter.create<tensor::ExtractOp>(loc, values, ivs[0]);
}

//===----------------------------------------------------------------------===//
// Conversion rules.
//===----------------------------------------------------------------------===//

/// Sparse conversion rule for returns.
class SparseReturnConverter : public OpConversionPattern<ReturnOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<ReturnOp>(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for dimension accesses.
class SparseTensorToDimSizeConverter
    : public OpConversionPattern<tensor::DimOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    auto enc = getSparseTensorEncoding(op.source().getType());
    if (!enc)
      return failure();
    // Permute the dim index.
    Optional<int64_t> index = op.getConstantIndex();
    if (!index.hasValue())
      return failure();
    int64_t idx = index.getValue();
    if (AffineMap p = enc.getDimOrdering())
      idx = p.getPermutedPosition(idx);
    // Generate the call.
    StringRef name = "sparseDimSize";
    SmallVector<Value, 2> params;
    params.push_back(adaptor.getOperands()[0]);
    params.push_back(constantIndex(rewriter, op.getLoc(), idx));
    rewriter.replaceOpWithNewOp<CallOp>(
        op, resType, getFunc(op, name, resType, params), params);
    return success();
  }
};
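
// For illustration, `%d = tensor.dim %t, %c1` on an annotated sparse tensor
// lowers to a sketch like the following (the tensor operand has already been
// converted to an opaque runtime pointer by the type converter):
//
//   %c1 = arith.constant 1 : index
//   %d = call @sparseDimSize(%ptr, %c1) : (!llvm.ptr<i8>, index) -> index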

/// Sparse conversion rule for the new operator.
class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NewOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    auto enc = getSparseTensorEncoding(resType);
    if (!enc)
      return failure();
    Value perm;
    rewriter.replaceOp(op, genNewCall(rewriter, op, enc, kFromFile, perm, {},
                                      adaptor.getOperands()[0]));
    return success();
  }
};

/// Sparse conversion rule for the init operator.
class SparseTensorInitConverter : public OpConversionPattern<InitOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(InitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    auto enc = getSparseTensorEncoding(resType);
    if (!enc)
      return failure();
    Value perm;
    rewriter.replaceOp(
        op, genNewCall(rewriter, op, enc, kEmpty, perm, adaptor.getOperands()));
    return success();
  }
};

/// Sparse conversion rule for the convert operator.
class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ConvertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    auto encDst = getSparseTensorEncoding(resType);
    auto encSrc = getSparseTensorEncoding(op.source().getType());
    auto src = adaptor.getOperands()[0];
    if (encDst && encSrc) {
      // This is a sparse => sparse conversion, which is handled as follows:
      //   t = src->toCOO();         ; src to COO in dst order
      //   dst = newSparseTensor(t)
      // Using the coordinate scheme as an intermediate does not always
      // yield the fastest conversion but avoids the need for a full
      // O(N^2) conversion matrix.
      Value perm;
      Value coo = genNewCall(rewriter, op, encDst, kToCOO, perm, {}, src);
      rewriter.replaceOp(
          op, genNewCall(rewriter, op, encDst, kFromCOO, perm, {}, coo));
      return success();
    }
    if (!encDst || encSrc) {
      // TODO: sparse => dense
      return failure();
    }
    // This is a dense => sparse conversion or a sparse constant in COO =>
    // sparse conversion, which is handled as follows:
    //   t = newSparseCOO()
    //   ...code to fill the COO tensor t...
    //   s = newSparseTensor(t)
    //
    // To fill the COO tensor from a dense tensor:
    //   for i1 in dim1
    //    ..
    //     for ik in dimk
    //       val = a[i1,..,ik]
    //       if val != 0
    //         t->add(val, [i1,..,ik], [p1,..,pk])
    //
    // To fill the COO tensor from a sparse constant in COO format:
    //   for i in range(NNZ)
    //     val = values[i]
    //     [i1,..,ik] = indices[i]
    //     t->add(val, [i1,..,ik], [p1,..,pk])
    //
    // Note that the dense tensor traversal code is actually implemented
    // using MLIR IR to avoid having to expose too much low-level
    // memref traversal details to the runtime support library.
    // Also note that the code below only generates the "new" ops and
    // the loop-nest per se; whereas the entire body of the innermost
    // loop is generated by genAddEltCall().
    Location loc = op->getLoc();
    ShapedType shape = resType.cast<ShapedType>();
    Value perm;
    Value ptr = genNewCall(rewriter, op, encDst, kEmptyCOO, perm, {});
    Value ind =
        genAlloca(rewriter, loc, shape.getRank(), rewriter.getIndexType());
    SmallVector<Value> lo;
    SmallVector<Value> hi;
    SmallVector<Value> st;
    Value zero = constantIndex(rewriter, loc, 0);
    Value one = constantIndex(rewriter, loc, 1);
    auto indicesValues = genSplitSparseConstant(rewriter, loc, src);
    bool isCOOConstant = indicesValues.hasValue();
    Value indices;
    Value values;
    if (isCOOConstant) {
      indices = indicesValues->first;
      values = indicesValues->second;
      lo.push_back(zero);
      hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, values, 0));
      st.push_back(one);
    } else {
      for (unsigned i = 0, rank = shape.getRank(); i < rank; i++) {
        lo.push_back(zero);
        hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, src, i));
        st.push_back(one);
      }
    }
    Type eltType = shape.getElementType();
    unsigned rank = shape.getRank();
    scf::buildLoopNest(
        rewriter, op.getLoc(), lo, hi, st, {},
        [&](OpBuilder &builder, Location loc, ValueRange ivs,
            ValueRange args) -> scf::ValueVector {
          Value val;
          if (isCOOConstant)
            val = genIndexAndValueForSparse(rewriter, loc, indices, values,
                                            ind, ivs, rank);
          else
            val = genIndexAndValueForDense(rewriter, loc, src, ind, ivs);
          genAddEltCall(rewriter, op, eltType, ptr, val, ind, perm);
          return {};
        });
    rewriter.replaceOp(
        op, genNewCall(rewriter, op, encDst, kFromCOO, perm, {}, ptr));
    return success();
  }
};
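
// For illustration, a dense => sparse conversion of a 1-d f64 tensor yields
// a loop nest resembling the following sketch (constants and the buffer
// setup for %ind and %perm omitted):
//
//   %ptr = call @newSparseTensor(...)            ; action kEmptyCOO
//   scf.for %i = %c0 to %d0 step %c1 {
//     %v = tensor.extract %src[%i] : tensor<?xf64>
//     %nz = arith.cmpf une, %v, %f0 : f64
//     scf.if %nz {
//       memref.store %i, %ind[%c0] : memref<?xindex>
//       %t = call @addEltF64(%ptr, %v, %ind, %perm) : ...
//     }
//   }
//   %dst = call @newSparseTensor(..., %ptr)      ; action kFromCOO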

/// Sparse conversion rule for the release operator.
class SparseTensorReleaseConverter : public OpConversionPattern<ReleaseOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReleaseOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    StringRef name = "delSparseTensor";
    TypeRange none;
    rewriter.create<CallOp>(op.getLoc(), none,
                            getFunc(op, name, none, adaptor.getOperands()),
                            adaptor.getOperands());
    rewriter.eraseOp(op);
    return success();
  }
};

/// Sparse conversion rule for pointer accesses.
class SparseTensorToPointersConverter
    : public OpConversionPattern<ToPointersOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToPointersOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type eltType = resType.cast<ShapedType>().getElementType();
    StringRef name;
    if (eltType.isIndex())
      name = "sparsePointers"; // 64-bit, but its own name for unique signature
    else if (eltType.isInteger(64))
      name = "sparsePointers64";
    else if (eltType.isInteger(32))
      name = "sparsePointers32";
    else if (eltType.isInteger(16))
      name = "sparsePointers16";
    else if (eltType.isInteger(8))
      name = "sparsePointers8";
    else
      return failure();
    rewriter.replaceOpWithNewOp<CallOp>(op, resType,
                                        getFunc(op, name, resType,
                                                adaptor.getOperands(),
                                                /*emitCInterface=*/true),
                                        adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for index accesses.
class SparseTensorToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToIndicesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type eltType = resType.cast<ShapedType>().getElementType();
    StringRef name;
    if (eltType.isIndex())
      name = "sparseIndices"; // 64-bit, but its own name for unique signature
    else if (eltType.isInteger(64))
      name = "sparseIndices64";
    else if (eltType.isInteger(32))
      name = "sparseIndices32";
    else if (eltType.isInteger(16))
      name = "sparseIndices16";
    else if (eltType.isInteger(8))
      name = "sparseIndices8";
    else
      return failure();
    rewriter.replaceOpWithNewOp<CallOp>(op, resType,
                                        getFunc(op, name, resType,
                                                adaptor.getOperands(),
                                                /*emitCInterface=*/true),
                                        adaptor.getOperands());
    return success();
  }
};
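
// For illustration, the pointer/index accessors lower to calls that return
// memrefs directly; e.g., for an `index` element type (a sketch):
//
//   %m = call @sparsePointers(%ptr, %c0)
//       : (!llvm.ptr<i8>, index) -> memref<?xindex>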

/// Sparse conversion rule for value accesses.
class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type eltType = resType.cast<ShapedType>().getElementType();
    StringRef name;
    if (eltType.isF64())
      name = "sparseValuesF64";
    else if (eltType.isF32())
      name = "sparseValuesF32";
    else if (eltType.isInteger(64))
      name = "sparseValuesI64";
    else if (eltType.isInteger(32))
      name = "sparseValuesI32";
    else if (eltType.isInteger(16))
      name = "sparseValuesI16";
    else if (eltType.isInteger(8))
      name = "sparseValuesI8";
    else
      return failure();
    rewriter.replaceOpWithNewOp<CallOp>(op, resType,
                                        getFunc(op, name, resType,
                                                adaptor.getOperands(),
                                                /*emitCInterface=*/true),
                                        adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for tensor reconstruction.
class SparseTensorToTensorConverter : public OpConversionPattern<ToTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  // Simply fold the operator into the pointer to the sparse storage scheme.
  LogicalResult
  matchAndRewrite(ToTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Check that all arguments of the tensor reconstruction operator are
    // calls into the support library that query exactly the same opaque
    // pointer.
    Value ptr;
    for (Value op : adaptor.getOperands()) {
      if (auto call = op.getDefiningOp<CallOp>()) {
        Value arg = call.getOperand(0);
        if (!arg.getType().isa<LLVM::LLVMPointerType>())
          return failure();
        if (!ptr)
          ptr = arg;
        else if (arg != ptr)
          return failure();
      }
    }
    // If a single opaque pointer is found, perform the folding.
    if (!ptr)
      return failure();
    rewriter.replaceOp(op, ptr);
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with conversion rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
                                                  RewritePatternSet &patterns) {
  patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter,
               SparseTensorNewConverter, SparseTensorInitConverter,
               SparseTensorConvertConverter, SparseTensorReleaseConverter,
               SparseTensorToPointersConverter, SparseTensorToIndicesConverter,
               SparseTensorToValuesConverter, SparseTensorToTensorConverter>(
      typeConverter, patterns.getContext());
}
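
// For illustration, a conversion pass would typically wire these patterns up
// roughly as follows (a sketch; `SparseTensorTypeConverter` is a hypothetical
// name for whatever type converter maps annotated tensor types to the opaque
// !llvm.ptr<i8>, and `target` still needs its legality rules configured):
//
//   RewritePatternSet patterns(ctx);
//   SparseTensorTypeConverter converter;
//   ConversionTarget target(*ctx);
//   populateSparseTensorConversionPatterns(converter, patterns);
//   if (failed(applyPartialConversion(getOperation(), target,
//                                     std::move(patterns))))
//     signalPassFailure();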