//===- SparseTensorConversion.cpp - Sparse tensor primitives conversion ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Convert sparse tensor primitives to calls into a runtime support library.
// Note that this is a current implementation choice to keep the conversion
// simple. In principle, these primitives could also be converted to actual
// elaborate IR code that implements the primitives on the selected sparse
// tensor storage schemes.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

/// New tensor storage action. Keep these values consistent with
/// the sparse runtime support library.
enum Action : uint32_t {
  kEmpty = 0,
  kFromFile = 1,
  kFromCOO = 2,
  kEmptyCOO = 3,
  kToCOO = 4
};

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

/// Returns internal type encoding for primary storage. Keep these
/// values consistent with the sparse runtime support library.
static uint32_t getPrimaryTypeEncoding(Type tp) {
  if (tp.isF64())
    return 1;
  if (tp.isF32())
    return 2;
  if (tp.isInteger(64))
    return 3;
  if (tp.isInteger(32))
    return 4;
  if (tp.isInteger(16))
    return 5;
  if (tp.isInteger(8))
    return 6;
  return 0;
}

/// Returns internal type encoding for overhead storage. Keep these
/// values consistent with the sparse runtime support library.
static uint32_t getOverheadTypeEncoding(unsigned width) {
  switch (width) {
  default:
    return 1;
  case 32:
    return 2;
  case 16:
    return 3;
  case 8:
    return 4;
  }
}

/// Returns internal dimension level type encoding. Keep these
/// values consistent with the sparse runtime support library.
static uint32_t
getDimLevelTypeEncoding(SparseTensorEncodingAttr::DimLevelType dlt) {
  switch (dlt) {
  case SparseTensorEncodingAttr::DimLevelType::Dense:
    return 0;
  case SparseTensorEncodingAttr::DimLevelType::Compressed:
    return 1;
  case SparseTensorEncodingAttr::DimLevelType::Singleton:
    return 2;
  }
  llvm_unreachable("Unknown SparseTensorEncodingAttr::DimLevelType");
}

/// Generates a constant zero of the given type.
inline static Value constantZero(ConversionPatternRewriter &rewriter,
                                 Location loc, Type t) {
  return rewriter.create<arith::ConstantOp>(loc, t, rewriter.getZeroAttr(t));
}
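
// Illustration (a sketch, not part of the original helpers): the constant
// generators in this section simply emit arith dialect constants, e.g.
//   %c0 = arith.constant 0 : index
//   %0  = arith.constant 0 : i32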

/// Generates a constant of `index` type.
inline static Value constantIndex(ConversionPatternRewriter &rewriter,
                                  Location loc, int64_t i) {
  return rewriter.create<arith::ConstantIndexOp>(loc, i);
}

/// Generates a constant of `i32` type.
inline static Value constantI32(ConversionPatternRewriter &rewriter,
                                Location loc, int32_t i) {
  return rewriter.create<arith::ConstantIntOp>(loc, i, 32);
}

/// Generates a constant of `i8` type.
inline static Value constantI8(ConversionPatternRewriter &rewriter,
                               Location loc, int8_t i) {
  return rewriter.create<arith::ConstantIntOp>(loc, i, 8);
}

/// Returns a function reference (first hit also inserts into the module).
/// Sets the "llvm.emit_c_interface" attribute on the function declaration
/// when requested, so that LLVM lowering generates a wrapper function that
/// takes care of ABI complications with passing in and returning MemRefs
/// to C functions.
static FlatSymbolRefAttr getFunc(Operation *op, StringRef name,
                                 TypeRange resultType, ValueRange operands,
                                 bool emitCInterface = false) {
  MLIRContext *context = op->getContext();
  auto module = op->getParentOfType<ModuleOp>();
  auto result = SymbolRefAttr::get(context, name);
  auto func = module.lookupSymbol<FuncOp>(result.getAttr());
  if (!func) {
    OpBuilder moduleBuilder(module.getBodyRegion());
    func = moduleBuilder.create<FuncOp>(
        op->getLoc(), name,
        FunctionType::get(context, operands.getTypes(), resultType));
    func.setPrivate();
    if (emitCInterface)
      func->setAttr("llvm.emit_c_interface", UnitAttr::get(context));
  }
  return result;
}

/// Generates a dimension size call.
static Value genDimSizeCall(ConversionPatternRewriter &rewriter, Operation *op,
                            SparseTensorEncodingAttr &enc, Value src,
                            int64_t idx) {
  // Permute the index according to an optional dimension ordering.
  if (AffineMap p = enc.getDimOrdering())
    idx = p.getPermutedPosition(idx);
  // Generate the call.
  Location loc = op->getLoc();
  StringRef name = "sparseDimSize";
  SmallVector<Value, 2> params;
  params.push_back(src);
  params.push_back(constantIndex(rewriter, loc, idx));
  Type iTp = rewriter.getIndexType();
  auto fn = getFunc(op, name, iTp, params);
  return rewriter.create<CallOp>(loc, iTp, fn, params).getResult(0);
}

/// Generates a call into the "swiss army knife" method of the sparse runtime
/// support library for materializing sparse tensors into the computation.
static Value genNewCall(ConversionPatternRewriter &rewriter, Operation *op,
                        ArrayRef<Value> params) {
  Location loc = op->getLoc();
  StringRef name = "newSparseTensor";
  Type pTp = LLVM::LLVMPointerType::get(rewriter.getI8Type());
  auto fn = getFunc(op, name, pTp, params, /*emitCInterface=*/true);
  auto call = rewriter.create<CallOp>(loc, pTp, fn, params);
  return call.getResult(0);
}

/// Populates the given sizes array from the type.
static void sizesFromType(ConversionPatternRewriter &rewriter,
                          SmallVector<Value, 4> &sizes, Location loc,
                          ShapedType stp) {
  auto shape = stp.getShape();
  for (unsigned i = 0, rank = stp.getRank(); i < rank; i++) {
    uint64_t s = shape[i] == ShapedType::kDynamicSize ? 0 : shape[i];
    sizes.push_back(constantIndex(rewriter, loc, s));
  }
}
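
// Illustration (a sketch, not in the original file): for the parameters
// assembled by newParams below, getFunc/genNewCall materialize a private
// declaration along the lines of
//   func private @newSparseTensor(memref<?xi8>, memref<?xindex>,
//       memref<?xindex>, i32, i32, i32, i32, !llvm.ptr<i8>) -> !llvm.ptr<i8>
//       attributes {llvm.emit_c_interface}
// plus a call to it at the current insertion point.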

/// Populates the given sizes array from the source.
static void sizesFromSrc(ConversionPatternRewriter &rewriter,
                         SmallVector<Value, 4> &sizes, Location loc,
                         Value src) {
  ShapedType stp = src.getType().cast<ShapedType>();
  for (unsigned i = 0, rank = stp.getRank(); i < rank; i++)
    sizes.push_back(linalg::createOrFoldDimOp(rewriter, loc, src, i));
}

/// Populates the given sizes array from the type (for static sizes) and from
/// a source already converted into an opaque pointer (for dynamic sizes).
static void sizesFromPtr(ConversionPatternRewriter &rewriter,
                         SmallVector<Value, 4> &sizes, Operation *op,
                         SparseTensorEncodingAttr &enc, ShapedType stp,
                         Value src) {
  auto shape = stp.getShape();
  for (unsigned i = 0, rank = stp.getRank(); i < rank; i++)
    if (shape[i] == ShapedType::kDynamicSize)
      sizes.push_back(genDimSizeCall(rewriter, op, enc, src, i));
    else
      sizes.push_back(constantIndex(rewriter, op->getLoc(), shape[i]));
}

/// Generates a temporary buffer of the given size and type.
static Value genAlloca(ConversionPatternRewriter &rewriter, Location loc,
                       unsigned sz, Type tp) {
  auto memTp = MemRefType::get({ShapedType::kDynamicSize}, tp);
  Value a = constantIndex(rewriter, loc, sz);
  return rewriter.create<memref::AllocaOp>(loc, memTp, ValueRange{a});
}

/// Generates a temporary buffer of the given type and given contents.
static Value genBuffer(ConversionPatternRewriter &rewriter, Location loc,
                       ArrayRef<Value> values) {
  unsigned sz = values.size();
  assert(sz >= 1);
  Value buffer = genAlloca(rewriter, loc, sz, values[0].getType());
  for (unsigned i = 0; i < sz; i++) {
    Value idx = constantIndex(rewriter, loc, i);
    rewriter.create<memref::StoreOp>(loc, values[i], buffer, idx);
  }
  return buffer;
}

/// Populates the parameters required to call the "swiss army knife" method
/// of the sparse runtime support library for materializing sparse tensors
/// into the computation.
static void newParams(ConversionPatternRewriter &rewriter,
                      SmallVector<Value, 8> &params, Operation *op,
                      SparseTensorEncodingAttr &enc, uint32_t action,
                      ValueRange szs, Value ptr = Value()) {
  Location loc = op->getLoc();
  ArrayRef<SparseTensorEncodingAttr::DimLevelType> dlt = enc.getDimLevelType();
  unsigned sz = dlt.size();
  // Sparsity annotations.
  SmallVector<Value, 4> attrs;
  for (unsigned i = 0; i < sz; i++)
    attrs.push_back(constantI8(rewriter, loc, getDimLevelTypeEncoding(dlt[i])));
  params.push_back(genBuffer(rewriter, loc, attrs));
  // Dimension sizes array of the enveloping tensor. Useful for either
  // verification of external data, or for construction of internal data.
  SmallVector<Value, 4> sizes;
  for (Value s : szs)
    sizes.push_back(s);
  params.push_back(genBuffer(rewriter, loc, sizes));
  // Dimension order permutation array. This is the "identity" permutation
  // by default, or otherwise the "reverse" permutation of a given ordering,
  // so that indices can be mapped quickly to the right position.
  SmallVector<Value, 4> rev(sz);
  if (AffineMap p = enc.getDimOrdering()) {
    for (unsigned i = 0; i < sz; i++)
      rev[p.getDimPosition(i)] = constantIndex(rewriter, loc, i);
  } else {
    for (unsigned i = 0; i < sz; i++)
      rev[i] = constantIndex(rewriter, loc, i);
  }
  params.push_back(genBuffer(rewriter, loc, rev));
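  // Illustrative example (not in the original file): with a dimension
  // ordering (d0, d1, d2) -> (d2, d0, d1), the loop above yields
  // rev = [1, 2, 0]; dimension d2 is stored first, hence rev[2] == 0.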
  // Secondary and primary types encoding.
  ShapedType resType = op->getResult(0).getType().cast<ShapedType>();
  uint32_t secPtr = getOverheadTypeEncoding(enc.getPointerBitWidth());
  uint32_t secInd = getOverheadTypeEncoding(enc.getIndexBitWidth());
  uint32_t primary = getPrimaryTypeEncoding(resType.getElementType());
  assert(primary);
  params.push_back(constantI32(rewriter, loc, secPtr));
  params.push_back(constantI32(rewriter, loc, secInd));
  params.push_back(constantI32(rewriter, loc, primary));
  // User action and pointer.
  Type pTp = LLVM::LLVMPointerType::get(rewriter.getI8Type());
  if (!ptr)
    ptr = rewriter.create<LLVM::NullOp>(loc, pTp);
  params.push_back(constantI32(rewriter, loc, action));
  params.push_back(ptr);
}

/// Generates the comparison `v != 0` where `v` is of numeric type `t`.
/// For floating types, we use the "unordered" comparator (i.e., returns
/// true if `v` is NaN).
static Value genIsNonzero(ConversionPatternRewriter &rewriter, Location loc,
                          Value v) {
  Type t = v.getType();
  Value zero = constantZero(rewriter, loc, t);
  if (t.isa<FloatType>())
    return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE, v,
                                          zero);
  if (t.isIntOrIndex())
    return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, v,
                                          zero);
  llvm_unreachable("Unknown element type");
}

/// Generates code to read the value from tensor[ivs] and, if it is nonzero,
/// store the indices ivs to the memory in ind. The generated code looks like
/// the following, and the insertion point after this routine is inside the
/// if-then branch, right after the assignment to ind. This ensures that the
/// addEltX call generated afterwards lands inside the if-then branch:
///    if (tensor[ivs] != 0) {
///      ind = ivs
static Value genIndexAndValueForDense(ConversionPatternRewriter &rewriter,
                                      Location loc, Value tensor, Value ind,
                                      ValueRange ivs) {
  Value val = rewriter.create<tensor::ExtractOp>(loc, tensor, ivs);
  Value cond = genIsNonzero(rewriter, loc, val);
  scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ false);
  rewriter.setInsertionPointToStart(&ifOp.thenRegion().front());
  unsigned i = 0;
  for (auto iv : ivs) {
    Value idx = constantIndex(rewriter, loc, i++);
    rewriter.create<memref::StoreOp>(loc, iv, ind, idx);
  }
  return val;
}
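
// Illustration (a sketch, not in the original file) of the IR generated
// above for a rank-2 dense tensor of f64, with the insertion point left
// inside the scf.if:
//   %val = tensor.extract %tensor[%i0, %i1] : tensor<?x?xf64>
//   %cmp = arith.cmpf une, %val, %zero : f64
//   scf.if %cmp {
//     memref.store %i0, %ind[%c0] : memref<?xindex>
//     memref.store %i1, %ind[%c1] : memref<?xindex>
//     // ... the addEltX call is emitted here by the caller ...
//   }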

/// Generates a call that adds one element to a coordinate scheme.
/// In particular, this generates code like the following:
///   val = a[i1,..,ik];
///   if val != 0
///     t->add(val, [i1,..,ik], [p1,..,pk]);
static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
                          Type eltType, Value ptr, Value val, Value ind,
                          Value perm) {
  Location loc = op->getLoc();
  StringRef name;
  if (eltType.isF64())
    name = "addEltF64";
  else if (eltType.isF32())
    name = "addEltF32";
  else if (eltType.isInteger(64))
    name = "addEltI64";
  else if (eltType.isInteger(32))
    name = "addEltI32";
  else if (eltType.isInteger(16))
    name = "addEltI16";
  else if (eltType.isInteger(8))
    name = "addEltI8";
  else
    llvm_unreachable("Unknown element type");
  SmallVector<Value, 8> params;
  params.push_back(ptr);
  params.push_back(val);
  params.push_back(ind);
  params.push_back(perm);
  Type pTp = LLVM::LLVMPointerType::get(rewriter.getI8Type());
  auto fn = getFunc(op, name, pTp, params, /*emitCInterface=*/true);
  rewriter.create<CallOp>(loc, pTp, fn, params);
}

/// If the tensor is a sparse constant, generates and returns the pair of
/// constants for the indices and the values.
static Optional<std::pair<Value, Value>>
genSplitSparseConstant(ConversionPatternRewriter &rewriter, Location loc,
                       Value tensor) {
  if (auto constOp = tensor.getDefiningOp<arith::ConstantOp>()) {
    if (auto attr = constOp.getValue().dyn_cast<SparseElementsAttr>()) {
      DenseElementsAttr indicesAttr = attr.getIndices();
      Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr);
      DenseElementsAttr valuesAttr = attr.getValues();
      Value values = rewriter.create<arith::ConstantOp>(loc, valuesAttr);
      return std::make_pair(indices, values);
    }
  }
  return {};
}

/// Generates code to copy the index at indices[ivs] to ind, and returns
/// the value at values[ivs].
static Value genIndexAndValueForSparse(ConversionPatternRewriter &rewriter,
                                       Location loc, Value indices,
                                       Value values, Value ind, ValueRange ivs,
                                       unsigned rank) {
  for (unsigned i = 0; i < rank; i++) {
    Value idx = constantIndex(rewriter, loc, i);
    Value val = rewriter.create<tensor::ExtractOp>(loc, indices,
                                                   ValueRange{ivs[0], idx});
    val =
        rewriter.create<arith::IndexCastOp>(loc, val, rewriter.getIndexType());
    rewriter.create<memref::StoreOp>(loc, val, ind, idx);
  }
  return rewriter.create<tensor::ExtractOp>(loc, values, ivs[0]);
}
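
// Illustrative example (not in the original file): a sparse constant such as
//   arith.constant sparse<[[0, 0], [1, 2]], [1.0, 5.0]> : tensor<2x3xf64>
// is split by genSplitSparseConstant into an indices constant of type
// tensor<2x2xi64> and a values constant of type tensor<2xf64>, over which
// genIndexAndValueForSparse then iterates one nonzero (ivs[0]) at a time.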

//===----------------------------------------------------------------------===//
// Conversion rules.
//===----------------------------------------------------------------------===//

/// Sparse conversion rule for returns.
class SparseReturnConverter : public OpConversionPattern<ReturnOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<ReturnOp>(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for dimension accesses.
class SparseTensorToDimSizeConverter
    : public OpConversionPattern<tensor::DimOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only rewrite annotated DimOp with constant index.
    auto enc = getSparseTensorEncoding(op.source().getType());
    if (!enc)
      return failure();
    Optional<int64_t> index = op.getConstantIndex();
    if (!index.hasValue())
      return failure();
    // Generate the call.
    Value src = adaptor.getOperands()[0];
    int64_t idx = index.getValue();
    rewriter.replaceOp(op, genDimSizeCall(rewriter, op, enc, src, idx));
    return success();
  }
};

/// Sparse conversion rule for trivial tensor casts.
class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only rewrite identically annotated source/dest.
    auto encDst = getSparseTensorEncoding(op.getType());
    auto encSrc = getSparseTensorEncoding(op.source().getType());
    if (!encDst || encDst != encSrc)
      return failure();
    rewriter.replaceOp(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for the new operator.
class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NewOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    auto enc = getSparseTensorEncoding(resType);
    if (!enc)
      return failure();
    // Generate the call to construct the tensor from a pointer. The sizes
    // are inferred from the result type of the new operator.
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 8> params;
    sizesFromType(rewriter, sizes, op.getLoc(), resType.cast<ShapedType>());
    Value ptr = adaptor.getOperands()[0];
    newParams(rewriter, params, op, enc, kFromFile, sizes, ptr);
    rewriter.replaceOp(op, genNewCall(rewriter, op, params));
    return success();
  }
};

/// Sparse conversion rule for the init operator.
class SparseTensorInitConverter : public OpConversionPattern<InitOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(InitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    auto enc = getSparseTensorEncoding(resType);
    if (!enc)
      return failure();
    // Generate the call to construct an empty tensor. The sizes are
    // explicitly defined by the arguments to the init operator.
    SmallVector<Value, 8> params;
    newParams(rewriter, params, op, enc, kEmpty, adaptor.getOperands());
    rewriter.replaceOp(op, genNewCall(rewriter, op, params));
    return success();
  }
};
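
// Illustration (a sketch, with a hypothetical #CSR encoding): the init
// conversion rewrites
//   %t = sparse_tensor.init [%d0, %d1] : tensor<?x?xf64, #CSR>
// into buffer setup for the annotations/sizes/permutation arrays followed by
//   %ptr = call @newSparseTensor(..., %c_kEmpty, %null) : (...) -> !llvm.ptr<i8>
// after which %ptr stands in for the sparse tensor value.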

/// Sparse conversion rule for the convert operator.
class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ConvertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    Type resType = op.getType();
    Type srcType = op.source().getType();
    auto encDst = getSparseTensorEncoding(resType);
    auto encSrc = getSparseTensorEncoding(srcType);
    Value src = adaptor.getOperands()[0];
    if (encDst && encSrc) {
      // This is a sparse => sparse conversion, which is handled as follows:
      //   t = src->toCOO();         ; src to COO in dst order
      //   dst = newSparseTensor(t)
      // Using the coordinate scheme as an intermediate does not always
      // yield the fastest conversion but avoids the need for a full
      // O(N^2) conversion matrix.
      if (encDst == encSrc) {
        rewriter.replaceOp(op, adaptor.getOperands()); // hidden nop cast
        return success();
      }
      SmallVector<Value, 4> sizes;
      SmallVector<Value, 8> params;
      sizesFromPtr(rewriter, sizes, op, encSrc, srcType.cast<ShapedType>(),
                   src);
      newParams(rewriter, params, op, encDst, kToCOO, sizes, src);
      Value coo = genNewCall(rewriter, op, params);
      params[6] = constantI32(rewriter, loc, kFromCOO);
      params[7] = coo;
      rewriter.replaceOp(op, genNewCall(rewriter, op, params));
      return success();
    }
    if (!encDst || encSrc) {
      // TODO: sparse => dense
      return failure();
    }
    // This is a dense => sparse conversion or a sparse constant in COO =>
    // sparse conversion, which is handled as follows:
    //   t = newSparseCOO()
    //   ...code to fill the COO tensor t...
    //   s = newSparseTensor(t)
    //
    // To fill the COO tensor from a dense tensor:
    //   for i1 in dim1
    //    ..
    //     for ik in dimk
    //       val = a[i1,..,ik]
    //       if val != 0
    //         t->add(val, [i1,..,ik], [p1,..,pk])
    //
    // To fill the COO tensor from a sparse constant in COO format:
    //   for i in range(NNZ)
    //     val = values[i]
    //     [i1,..,ik] = indices[i]
    //     t->add(val, [i1,..,ik], [p1,..,pk])
    //
    // Note that the dense tensor traversal code is actually implemented
    // using MLIR IR to avoid having to expose too much low-level
    // memref traversal detail to the runtime support library.
    // Also note that the code below only generates the "new" ops and
    // the loop-nest per se; the entire body of the innermost loop is
    // generated by genAddEltCall() and the genIndexAndValue helpers.
    ShapedType stp = resType.cast<ShapedType>();
    unsigned rank = stp.getRank();
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 8> params;
    sizesFromSrc(rewriter, sizes, loc, src);
    newParams(rewriter, params, op, encDst, kEmptyCOO, sizes);
    Value ptr = genNewCall(rewriter, op, params);
    Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
    Value perm = params[2];
    SmallVector<Value> lo;
    SmallVector<Value> hi;
    SmallVector<Value> st;
    Value zero = constantIndex(rewriter, loc, 0);
    Value one = constantIndex(rewriter, loc, 1);
    auto indicesValues = genSplitSparseConstant(rewriter, loc, src);
    bool isCOOConstant = indicesValues.hasValue();
    Value indices;
    Value values;
    if (isCOOConstant) {
      indices = indicesValues->first;
      values = indicesValues->second;
      lo.push_back(zero);
      hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, values, 0));
      st.push_back(one);
    } else {
      for (unsigned i = 0; i < rank; i++) {
        lo.push_back(zero);
        hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, src, i));
        st.push_back(one);
      }
    }
    Type eltType = stp.getElementType();
    scf::buildLoopNest(
        rewriter, op.getLoc(), lo, hi, st, {},
        [&](OpBuilder &builder, Location loc, ValueRange ivs,
            ValueRange args) -> scf::ValueVector {
          Value val;
          if (isCOOConstant)
            val = genIndexAndValueForSparse(rewriter, loc, indices, values,
                                            ind, ivs, rank);
          else
            val = genIndexAndValueForDense(rewriter, loc, src, ind, ivs);
          genAddEltCall(rewriter, op, eltType, ptr, val, ind, perm);
          return {};
        });
    // Final call to construct the sparse tensor storage.
    params[6] = constantI32(rewriter, loc, kFromCOO);
    params[7] = ptr;
    rewriter.replaceOp(op, genNewCall(rewriter, op, params));
    return success();
  }
};

/// Sparse conversion rule for the release operator.
class SparseTensorReleaseConverter : public OpConversionPattern<ReleaseOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReleaseOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    StringRef name = "delSparseTensor";
    TypeRange none;
    auto fn = getFunc(op, name, none, adaptor.getOperands());
    rewriter.create<CallOp>(op.getLoc(), none, fn, adaptor.getOperands());
    rewriter.eraseOp(op);
    return success();
  }
};

/// Sparse conversion rule for pointer accesses.
class SparseTensorToPointersConverter
    : public OpConversionPattern<ToPointersOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToPointersOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type eltType = resType.cast<ShapedType>().getElementType();
    StringRef name;
    if (eltType.isIndex())
      name = "sparsePointers";
    else if (eltType.isInteger(64))
      name = "sparsePointers64";
    else if (eltType.isInteger(32))
      name = "sparsePointers32";
    else if (eltType.isInteger(16))
      name = "sparsePointers16";
    else if (eltType.isInteger(8))
      name = "sparsePointers8";
    else
      return failure();
    auto fn = getFunc(op, name, resType, adaptor.getOperands(),
                      /*emitCInterface=*/true);
    rewriter.replaceOpWithNewOp<CallOp>(op, resType, fn, adaptor.getOperands());
    return success();
  }
};
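
// Note (added for clarity): since getFunc sets llvm.emit_c_interface on these
// declarations, LLVM lowering also emits interface wrappers (conventionally
// named _mlir_ciface_<name>, e.g. _mlir_ciface_sparsePointers) that pass
// MemRef descriptors by pointer, which is the ABI the runtime library
// implements for these query functions.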

/// Sparse conversion rule for index accesses.
class SparseTensorToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToIndicesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type eltType = resType.cast<ShapedType>().getElementType();
    StringRef name;
    if (eltType.isIndex())
      name = "sparseIndices";
    else if (eltType.isInteger(64))
      name = "sparseIndices64";
    else if (eltType.isInteger(32))
      name = "sparseIndices32";
    else if (eltType.isInteger(16))
      name = "sparseIndices16";
    else if (eltType.isInteger(8))
      name = "sparseIndices8";
    else
      return failure();
    auto fn = getFunc(op, name, resType, adaptor.getOperands(),
                      /*emitCInterface=*/true);
    rewriter.replaceOpWithNewOp<CallOp>(op, resType, fn, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for value accesses.
class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type eltType = resType.cast<ShapedType>().getElementType();
    StringRef name;
    if (eltType.isF64())
      name = "sparseValuesF64";
    else if (eltType.isF32())
      name = "sparseValuesF32";
    else if (eltType.isInteger(64))
      name = "sparseValuesI64";
    else if (eltType.isInteger(32))
      name = "sparseValuesI32";
    else if (eltType.isInteger(16))
      name = "sparseValuesI16";
    else if (eltType.isInteger(8))
      name = "sparseValuesI8";
    else
      return failure();
    auto fn = getFunc(op, name, resType, adaptor.getOperands(),
                      /*emitCInterface=*/true);
    rewriter.replaceOpWithNewOp<CallOp>(op, resType, fn, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for tensor reconstruction.
class SparseTensorToTensorConverter : public OpConversionPattern<ToTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  // Simply fold the operator into the pointer to the sparse storage scheme.
  LogicalResult
  matchAndRewrite(ToTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Check that all arguments of the tensor reconstruction operator are
    // calls into the support library that query exactly the same opaque
    // pointer.
    Value ptr;
    for (Value operand : adaptor.getOperands()) {
      if (auto call = operand.getDefiningOp<CallOp>()) {
        Value arg = call.getOperand(0);
        if (!arg.getType().isa<LLVM::LLVMPointerType>())
          return failure();
        if (!ptr)
          ptr = arg;
        else if (arg != ptr)
          return failure();
      }
    }
    // If a single opaque pointer is found, perform the folding.
    if (!ptr)
      return failure();
    rewriter.replaceOp(op, ptr);
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with conversion rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
                                                  RewritePatternSet &patterns) {
  patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter,
               SparseCastConverter, SparseTensorNewConverter,
               SparseTensorInitConverter, SparseTensorConvertConverter,
               SparseTensorReleaseConverter, SparseTensorToPointersConverter,
               SparseTensorToIndicesConverter, SparseTensorToValuesConverter,
               SparseTensorToTensorConverter>(typeConverter,
                                              patterns.getContext());
}
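
// Usage sketch (illustrative, not part of this file): a conversion pass
// typically wires these patterns up roughly as follows, with a type
// converter that maps annotated tensor types to !llvm.ptr<i8>:
//
//   RewritePatternSet patterns(ctx);
//   SparseTensorTypeConverter converter; // hypothetical converter name
//   ConversionTarget target(*ctx);
//   populateSparseTensorConversionPatterns(converter, patterns);
//   if (failed(applyPartialConversion(module, target, std::move(patterns))))
//     signalPassFailure();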