//===- SparseTensorConversion.cpp - Sparse tensor primitives conversion ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Convert sparse tensor primitives to calls into a runtime support library.
// Note that this is a current implementation choice to keep the conversion
// simple. In principle, these primitives could also be converted to actual
// elaborate IR code that implements the primitives on the selected sparse
// tensor storage schemes.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

/// New tensor storage action. Keep these values consistent with
/// the sparse runtime support library.
enum Action : uint32_t {
  kEmpty = 0,
  kFromFile = 1,
  kFromCOO = 2,
  kEmptyCOO = 3,
  kToCOO = 4
};

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

/// Returns internal type encoding for primary storage. Keep these
/// values consistent with the sparse runtime support library.
static unsigned getPrimaryTypeEncoding(Type tp) {
  if (tp.isF64())
    return 1;
  if (tp.isF32())
    return 2;
  if (tp.isInteger(64))
    return 3;
  if (tp.isInteger(32))
    return 4;
  if (tp.isInteger(16))
    return 5;
  if (tp.isInteger(8))
    return 6;
  return 0;
}

/// Returns internal type encoding for overhead storage. Keep these
/// values consistent with the sparse runtime support library.
static unsigned getOverheadTypeEncoding(unsigned width) {
  switch (width) {
  default:
    return 1;
  case 32:
    return 2;
  case 16:
    return 3;
  case 8:
    return 4;
  }
}

/// Returns internal dimension level type encoding. Keep these
/// values consistent with the sparse runtime support library.
static unsigned
getDimLevelTypeEncoding(SparseTensorEncodingAttr::DimLevelType dlt) {
  switch (dlt) {
  case SparseTensorEncodingAttr::DimLevelType::Dense:
    return 0;
  case SparseTensorEncodingAttr::DimLevelType::Compressed:
    return 1;
  case SparseTensorEncodingAttr::DimLevelType::Singleton:
    return 2;
  }
  llvm_unreachable("Unknown SparseTensorEncodingAttr::DimLevelType");
}

/// Generates a constant zero of the given type.
inline static Value constantZero(ConversionPatternRewriter &rewriter,
                                 Location loc, Type t) {
  return rewriter.create<arith::ConstantOp>(loc, t, rewriter.getZeroAttr(t));
}

/// Generates a constant of `index` type.
inline static Value constantIndex(ConversionPatternRewriter &rewriter,
                                  Location loc, int64_t i) {
  return rewriter.create<arith::ConstantIndexOp>(loc, i);
}

/// Generates a constant of `i64` type.
inline static Value constantI64(ConversionPatternRewriter &rewriter,
                                Location loc, int64_t i) {
  return rewriter.create<arith::ConstantIntOp>(loc, i, 64);
}

/// Generates a constant of `i32` type.
inline static Value constantI32(ConversionPatternRewriter &rewriter,
                                Location loc, int32_t i) {
  return rewriter.create<arith::ConstantIntOp>(loc, i, 32);
}

/// Generates a constant of `i8` type.
inline static Value constantI8(ConversionPatternRewriter &rewriter,
                               Location loc, int8_t i) {
  return rewriter.create<arith::ConstantIntOp>(loc, i, 8);
}

/// Returns a function reference (first hit also inserts into module). Sets
/// the "llvm.emit_c_interface" attribute on the function declaration when
/// requested, so that LLVM lowering generates a wrapper function that takes
/// care of ABI complications with passing in and returning MemRefs to C
/// functions.
static FlatSymbolRefAttr getFunc(Operation *op, StringRef name,
                                 TypeRange resultType, ValueRange operands,
                                 bool emitCInterface = false) {
  MLIRContext *context = op->getContext();
  auto module = op->getParentOfType<ModuleOp>();
  auto result = SymbolRefAttr::get(context, name);
  auto func = module.lookupSymbol<FuncOp>(result.getAttr());
  if (!func) {
    OpBuilder moduleBuilder(module.getBodyRegion());
    func = moduleBuilder.create<FuncOp>(
        op->getLoc(), name,
        FunctionType::get(context, operands.getTypes(), resultType));
    func.setPrivate();
    if (emitCInterface)
      func->setAttr("llvm.emit_c_interface", UnitAttr::get(context));
  }
  return result;
}
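// For example, a first lookup of a hypothetical function "foo" through
//   getFunc(op, "foo", iTp, params);
// with an index result type and a single !llvm.ptr<i8> operand inserts,
// roughly, the private declaration
//   func private @foo(!llvm.ptr<i8>) -> index
// at module scope; later hits simply return the existing symbol.
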
/// Generates dimension size call.
static Value genDimSizeCall(ConversionPatternRewriter &rewriter, Operation *op,
                            SparseTensorEncodingAttr &enc, Value src,
                            int64_t idx) {
  // Permute the index according to an optional dimension ordering.
  if (AffineMap p = enc.getDimOrdering())
    idx = p.getPermutedPosition(idx);
  // Generate the call.
  Location loc = op->getLoc();
  StringRef name = "sparseDimSize";
  SmallVector<Value, 2> params;
  params.push_back(src);
  params.push_back(constantIndex(rewriter, loc, idx));
  Type iTp = rewriter.getIndexType();
  auto fn = getFunc(op, name, iTp, params);
  return rewriter.create<CallOp>(loc, iTp, fn, params).getResult(0);
}
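// For example, querying the size of dimension 1 of a sparse tensor emits,
// roughly:
//   %c1 = arith.constant 1 : index
//   %d = call @sparseDimSize(%src, %c1) : (!llvm.ptr<i8>, index) -> index
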
/// Generates a call into the "swiss army knife" method of the sparse runtime
/// support library for materializing sparse tensors into the computation.
static Value genNewCall(ConversionPatternRewriter &rewriter, Operation *op,
                        ArrayRef<Value> params) {
  Location loc = op->getLoc();
  StringRef name = "newSparseTensor";
  Type pTp = LLVM::LLVMPointerType::get(rewriter.getI8Type());
  auto fn = getFunc(op, name, pTp, params, /*emitCInterface=*/true);
  auto call = rewriter.create<CallOp>(loc, pTp, fn, params);
  return call.getResult(0);
}

/// Populates given sizes array from type.
static void sizesFromType(ConversionPatternRewriter &rewriter,
                          SmallVector<Value, 4> &sizes, Location loc,
                          ShapedType stp) {
  auto shape = stp.getShape();
  for (unsigned i = 0, rank = stp.getRank(); i < rank; i++) {
    uint64_t s = shape[i] == ShapedType::kDynamicSize ? 0 : shape[i];
    sizes.push_back(constantIndex(rewriter, loc, s));
  }
}

/// Populates given sizes array from source.
static void sizesFromSrc(ConversionPatternRewriter &rewriter,
                         SmallVector<Value, 4> &sizes, Location loc,
                         Value src) {
  ShapedType stp = src.getType().cast<ShapedType>();
  for (unsigned i = 0, rank = stp.getRank(); i < rank; i++)
    sizes.push_back(linalg::createOrFoldDimOp(rewriter, loc, src, i));
}

/// Populates given sizes array from type (for static sizes) and from
/// a source already converted into an opaque pointer (for dynamic sizes).
static void sizesFromPtr(ConversionPatternRewriter &rewriter,
                         SmallVector<Value, 4> &sizes, Operation *op,
                         SparseTensorEncodingAttr &enc, ShapedType stp,
                         Value src) {
  auto shape = stp.getShape();
  for (unsigned i = 0, rank = stp.getRank(); i < rank; i++)
    if (shape[i] == ShapedType::kDynamicSize)
      sizes.push_back(genDimSizeCall(rewriter, op, enc, src, i));
    else
      sizes.push_back(constantIndex(rewriter, op->getLoc(), shape[i]));
}

/// Generates a temporary buffer of the given size and type.
static Value genAlloca(ConversionPatternRewriter &rewriter, Location loc,
                       unsigned sz, Type tp) {
  auto memTp = MemRefType::get({ShapedType::kDynamicSize}, tp);
  Value a = constantIndex(rewriter, loc, sz);
  return rewriter.create<memref::AllocaOp>(loc, memTp, ValueRange{a});
}

/// Generates a temporary buffer of the given type and given contents.
static Value genBuffer(ConversionPatternRewriter &rewriter, Location loc,
                       ArrayRef<Value> values) {
  unsigned sz = values.size();
  assert(sz >= 1);
  Value buffer = genAlloca(rewriter, loc, sz, values[0].getType());
  for (unsigned i = 0; i < sz; i++) {
    Value idx = constantIndex(rewriter, loc, i);
    rewriter.create<memref::StoreOp>(loc, values[i], buffer, idx);
  }
  return buffer;
}

/// Populates parameters required to call the "swiss army knife" method of the
/// sparse runtime support library for materializing sparse tensors into the
/// computation.
static void newParams(ConversionPatternRewriter &rewriter,
                      SmallVector<Value, 8> &params, Operation *op,
                      SparseTensorEncodingAttr &enc, uint32_t action,
                      ValueRange szs, Value ptr = Value()) {
  Location loc = op->getLoc();
  ArrayRef<SparseTensorEncodingAttr::DimLevelType> dlt = enc.getDimLevelType();
  unsigned sz = dlt.size();
  // Sparsity annotations.
  SmallVector<Value, 4> attrs;
  for (unsigned i = 0; i < sz; i++)
    attrs.push_back(constantI8(rewriter, loc, getDimLevelTypeEncoding(dlt[i])));
  params.push_back(genBuffer(rewriter, loc, attrs));
  // Dimension sizes array of the enveloping tensor. Useful for either
  // verification of external data, or for construction of internal data.
  // The index type is cast to i64 for API consistency.
  Type iTp = rewriter.getI64Type();
  SmallVector<Value, 4> sizes;
  for (Value s : szs)
    sizes.push_back(rewriter.create<arith::IndexCastOp>(loc, s, iTp));
  params.push_back(genBuffer(rewriter, loc, sizes));
  // Dimension order permutation array. This is the "identity" permutation by
  // default, or otherwise the "reverse" permutation of a given ordering, so
  // that indices can be mapped quickly to the right position.
  SmallVector<Value, 4> rev(sz);
  if (AffineMap p = enc.getDimOrdering()) {
    for (unsigned i = 0; i < sz; i++)
      rev[p.getDimPosition(i)] = constantI64(rewriter, loc, i);
  } else {
    for (unsigned i = 0; i < sz; i++)
      rev[i] = constantI64(rewriter, loc, i);
  }
  params.push_back(genBuffer(rewriter, loc, rev));
  // Secondary and primary types encoding.
  ShapedType resType = op->getResult(0).getType().cast<ShapedType>();
  unsigned secPtr = getOverheadTypeEncoding(enc.getPointerBitWidth());
  unsigned secInd = getOverheadTypeEncoding(enc.getIndexBitWidth());
  unsigned primary = getPrimaryTypeEncoding(resType.getElementType());
  assert(primary);
  params.push_back(constantI64(rewriter, loc, secPtr));
  params.push_back(constantI64(rewriter, loc, secInd));
  params.push_back(constantI64(rewriter, loc, primary));
  // User action and pointer.
  Type pTp = LLVM::LLVMPointerType::get(rewriter.getI8Type());
  if (!ptr)
    ptr = rewriter.create<LLVM::NullOp>(loc, pTp);
  params.push_back(constantI32(rewriter, loc, action));
  params.push_back(ptr);
}
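// For reference, the parameter vector built above has the following layout,
// which the conversion rules below rely on when patching params[6] and
// params[7] for a follow-up call:
//   params[0] : memref<?xi8>   per-dimension level-type encodings
//   params[1] : memref<?xi64>  dimension sizes of the enveloping tensor
//   params[2] : memref<?xi64>  inverse dimension permutation
//   params[3] : i64            pointer overhead type encoding
//   params[4] : i64            index overhead type encoding
//   params[5] : i64            primary element type encoding
//   params[6] : i32            action
//   params[7] : !llvm.ptr<i8>  opaque pointer (null when not supplied)
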
/// Generates the comparison `v != 0` where `v` is of numeric type `t`.
/// For floating types, we use the "unordered" comparator (i.e., returns
/// true if `v` is NaN).
static Value genIsNonzero(ConversionPatternRewriter &rewriter, Location loc,
                          Value v) {
  Type t = v.getType();
  Value zero = constantZero(rewriter, loc, t);
  if (t.isa<FloatType>())
    return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE, v,
                                          zero);
  if (t.isIntOrIndex())
    return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, v,
                                          zero);
  llvm_unreachable("Unknown element type");
}

/// Generates the code to read the value from tensor[ivs], and conditionally
/// stores the indices ivs to the memory in ind. The generated code looks
/// like the following, and the insertion point after this routine is inside
/// the if-then branch, right after the assignment to ind. This ensures that
/// the addEltX call generated afterwards ends up inside the if-then branch:
///   if (tensor[ivs] != 0) {
///     ind = ivs
static Value genIndexAndValueForDense(ConversionPatternRewriter &rewriter,
                                      Location loc, Value tensor, Value ind,
                                      ValueRange ivs) {
  Value val = rewriter.create<tensor::ExtractOp>(loc, tensor, ivs);
  Value cond = genIsNonzero(rewriter, loc, val);
  scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ false);
  rewriter.setInsertionPointToStart(&ifOp.thenRegion().front());
  unsigned i = 0;
  for (auto iv : ivs) {
    Value idx = constantIndex(rewriter, loc, i++);
    rewriter.create<memref::StoreOp>(loc, iv, ind, idx);
  }
  return val;
}

/// Generates a call that adds one element to a coordinate scheme.
/// In particular, this generates code like the following:
///   val = a[i1,..,ik];
///   if val != 0
///     t->add(val, [i1,..,ik], [p1,..,pk]);
static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
                          Type eltType, Value ptr, Value val, Value ind,
                          Value perm) {
  Location loc = op->getLoc();
  StringRef name;
  if (eltType.isF64())
    name = "addEltF64";
  else if (eltType.isF32())
    name = "addEltF32";
  else if (eltType.isInteger(64))
    name = "addEltI64";
  else if (eltType.isInteger(32))
    name = "addEltI32";
  else if (eltType.isInteger(16))
    name = "addEltI16";
  else if (eltType.isInteger(8))
    name = "addEltI8";
  else
    llvm_unreachable("Unknown element type");
  SmallVector<Value, 8> params;
  params.push_back(ptr);
  params.push_back(val);
  params.push_back(ind);
  params.push_back(perm);
  Type pTp = LLVM::LLVMPointerType::get(rewriter.getI8Type());
  auto fn = getFunc(op, name, pTp, params, /*emitCInterface=*/true);
  rewriter.create<CallOp>(loc, pTp, fn, params);
}

/// If the tensor is a sparse constant, generates and returns the pair of
/// the constants for the indices and the values.
static Optional<std::pair<Value, Value>>
genSplitSparseConstant(ConversionPatternRewriter &rewriter, Location loc,
                       Value tensor) {
  if (auto constOp = tensor.getDefiningOp<arith::ConstantOp>()) {
    if (auto attr = constOp.value().dyn_cast<SparseElementsAttr>()) {
      DenseElementsAttr indicesAttr = attr.getIndices();
      Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr);
      DenseElementsAttr valuesAttr = attr.getValues();
      Value values = rewriter.create<arith::ConstantOp>(loc, valuesAttr);
      return std::make_pair(indices, values);
    }
  }
  return {};
}
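// For example, a sparse constant
//   %src = arith.constant sparse<[[0, 0], [1, 6]], [1.0, 5.0]>
//        : tensor<8x7xf64>
// is split, roughly, into separate index and value constants:
//   %indices = arith.constant dense<[[0, 0], [1, 6]]> : tensor<2x2xi64>
//   %values  = arith.constant dense<[1.0, 5.0]> : tensor<2xf64>
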
/// Generates the code to copy the index at indices[ivs[0], 0..rank-1] into
/// ind, and returns the value at values[ivs[0]].
static Value genIndexAndValueForSparse(ConversionPatternRewriter &rewriter,
                                       Location loc, Value indices,
                                       Value values, Value ind, ValueRange ivs,
                                       unsigned rank) {
  for (unsigned i = 0; i < rank; i++) {
    Value idx = constantIndex(rewriter, loc, i);
    Value val = rewriter.create<tensor::ExtractOp>(loc, indices,
                                                   ValueRange{ivs[0], idx});
    val =
        rewriter.create<arith::IndexCastOp>(loc, val, rewriter.getIndexType());
    rewriter.create<memref::StoreOp>(loc, val, ind, idx);
  }
  return rewriter.create<tensor::ExtractOp>(loc, values, ivs[0]);
}

//===----------------------------------------------------------------------===//
// Conversion rules.
//===----------------------------------------------------------------------===//

/// Sparse conversion rule for returns.
class SparseReturnConverter : public OpConversionPattern<ReturnOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<ReturnOp>(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for dimension accesses.
class SparseTensorToDimSizeConverter
    : public OpConversionPattern<tensor::DimOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only rewrite annotated DimOp with constant index.
    auto enc = getSparseTensorEncoding(op.source().getType());
    if (!enc)
      return failure();
    Optional<int64_t> index = op.getConstantIndex();
    if (!index.hasValue())
      return failure();
    // Generate the call.
    Value src = adaptor.getOperands()[0];
    int64_t idx = index.getValue();
    rewriter.replaceOp(op, genDimSizeCall(rewriter, op, enc, src, idx));
    return success();
  }
};

/// Sparse conversion rule for the new operator.
class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NewOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    auto enc = getSparseTensorEncoding(resType);
    if (!enc)
      return failure();
    // Generate the call to construct tensor from ptr. The sizes are
    // inferred from the result type of the new operator.
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 8> params;
    sizesFromType(rewriter, sizes, op.getLoc(), resType.cast<ShapedType>());
    Value ptr = adaptor.getOperands()[0];
    newParams(rewriter, params, op, enc, kFromFile, sizes, ptr);
    rewriter.replaceOp(op, genNewCall(rewriter, op, params));
    return success();
  }
};
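// For example, assuming a hypothetical #CSR encoding, a materialization
//   %t = sparse_tensor.new %ptr : !llvm.ptr<i8> to tensor<8x8xf64, #CSR>
// lowers, roughly, to a single @newSparseTensor call with action kFromFile
// and %ptr passed as the opaque pointer argument.
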
/// Sparse conversion rule for the init operator.
class SparseTensorInitConverter : public OpConversionPattern<InitOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(InitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    auto enc = getSparseTensorEncoding(resType);
    if (!enc)
      return failure();
    // Generate the call to construct empty tensor. The sizes are
    // explicitly defined by the arguments to the init operator.
    SmallVector<Value, 8> params;
    newParams(rewriter, params, op, enc, kEmpty, adaptor.getOperands());
    rewriter.replaceOp(op, genNewCall(rewriter, op, params));
    return success();
  }
};

/// Sparse conversion rule for the convert operator.
class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ConvertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    Type resType = op.getType();
    Type srcType = op.source().getType();
    auto encDst = getSparseTensorEncoding(resType);
    auto encSrc = getSparseTensorEncoding(srcType);
    Value src = adaptor.getOperands()[0];
    if (encDst && encSrc) {
      // This is a sparse => sparse conversion, which is handled as follows:
      //   t = src->toCOO();        ; src to COO in dst order
      //   dst = newSparseTensor(t)
      // Using the coordinate scheme as an intermediate does not always
      // yield the fastest conversion but avoids the need for a full
      // O(N^2) conversion matrix.
      SmallVector<Value, 4> sizes;
      SmallVector<Value, 8> params;
      sizesFromPtr(rewriter, sizes, op, encSrc, srcType.cast<ShapedType>(),
                   src);
      newParams(rewriter, params, op, encDst, kToCOO, sizes, src);
      Value coo = genNewCall(rewriter, op, params);
      params[6] = constantI32(rewriter, loc, kFromCOO);
      params[7] = coo;
      rewriter.replaceOp(op, genNewCall(rewriter, op, params));
      return success();
    }
    if (!encDst || encSrc) {
      // TODO: sparse => dense
      return failure();
    }
    // This is a dense => sparse conversion or a sparse constant in COO =>
    // sparse conversion, which is handled as follows:
    //   t = newSparseCOO()
    //   ...code to fill the COO tensor t...
    //   s = newSparseTensor(t)
    //
    // To fill the COO tensor from a dense tensor:
    //   for i1 in dim1
    //     ..
    //       for ik in dimk
    //         val = a[i1,..,ik]
    //         if val != 0
    //           t->add(val, [i1,..,ik], [p1,..,pk])
    //
    // To fill the COO tensor from a sparse constant in COO format:
    //   for i in range(NNZ)
    //     val = values[i]
    //     [i1,..,ik] = indices[i]
    //     t->add(val, [i1,..,ik], [p1,..,pk])
    //
    // Note that the dense tensor traversal code is actually implemented
    // using MLIR IR to avoid exposing too many low-level memref traversal
    // details to the runtime support library.
    // Also note that the code below only generates the "new" ops and
    // the loop-nest per se; the entire body of the innermost loop is
    // generated by genIndexAndValueForX() and genAddEltCall().
    ShapedType stp = resType.cast<ShapedType>();
    unsigned rank = stp.getRank();
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 8> params;
    sizesFromSrc(rewriter, sizes, loc, src);
    newParams(rewriter, params, op, encDst, kEmptyCOO, sizes);
    Value ptr = genNewCall(rewriter, op, params);
    Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
    Value perm = params[2];
    SmallVector<Value> lo;
    SmallVector<Value> hi;
    SmallVector<Value> st;
    Value zero = constantIndex(rewriter, loc, 0);
    Value one = constantIndex(rewriter, loc, 1);
    auto indicesValues = genSplitSparseConstant(rewriter, loc, src);
    bool isCOOConstant = indicesValues.hasValue();
    Value indices;
    Value values;
    if (isCOOConstant) {
      indices = indicesValues->first;
      values = indicesValues->second;
      lo.push_back(zero);
      hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, values, 0));
      st.push_back(one);
    } else {
      for (unsigned i = 0; i < rank; i++) {
        lo.push_back(zero);
        hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, src, i));
        st.push_back(one);
      }
    }
    Type eltType = stp.getElementType();
    scf::buildLoopNest(
        rewriter, op.getLoc(), lo, hi, st, {},
        [&](OpBuilder &builder, Location loc, ValueRange ivs,
            ValueRange args) -> scf::ValueVector {
          Value val;
          if (isCOOConstant)
            val = genIndexAndValueForSparse(rewriter, loc, indices, values,
                                            ind, ivs, rank);
          else
            val = genIndexAndValueForDense(rewriter, loc, src, ind, ivs);
          genAddEltCall(rewriter, op, eltType, ptr, val, ind, perm);
          return {};
        });
    // Final call to construct sparse tensor storage.
    params[6] = constantI32(rewriter, loc, kFromCOO);
    params[7] = ptr;
    rewriter.replaceOp(op, genNewCall(rewriter, op, params));
    return success();
  }
};
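// For reference, the two supported convert paths above produce call
// sequences that look roughly as follows (actions shown symbolically):
//   sparse => sparse:
//     %coo = call @newSparseTensor(...)  ; action kToCOO,    ptr = %src
//     %dst = call @newSparseTensor(...)  ; action kFromCOO,  ptr = %coo
//   dense or COO constant => sparse:
//     %coo = call @newSparseTensor(...)  ; action kEmptyCOO, ptr = null
//     <loop nest calling @addEltX for every (nonzero) element>
//     %dst = call @newSparseTensor(...)  ; action kFromCOO,  ptr = %coo
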
/// Sparse conversion rule for the release operator.
class SparseTensorReleaseConverter : public OpConversionPattern<ReleaseOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReleaseOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    StringRef name = "delSparseTensor";
    TypeRange none;
    auto fn = getFunc(op, name, none, adaptor.getOperands());
    rewriter.create<CallOp>(op.getLoc(), none, fn, adaptor.getOperands());
    rewriter.eraseOp(op);
    return success();
  }
};

/// Sparse conversion rule for pointer accesses.
class SparseTensorToPointersConverter
    : public OpConversionPattern<ToPointersOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToPointersOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type eltType = resType.cast<ShapedType>().getElementType();
    StringRef name;
    if (eltType.isIndex())
      name = "sparsePointers"; // 64-bit, but its own name for unique signature
    else if (eltType.isInteger(64))
      name = "sparsePointers64";
    else if (eltType.isInteger(32))
      name = "sparsePointers32";
    else if (eltType.isInteger(16))
      name = "sparsePointers16";
    else if (eltType.isInteger(8))
      name = "sparsePointers8";
    else
      return failure();
    auto fn = getFunc(op, name, resType, adaptor.getOperands(),
                      /*emitCInterface=*/true);
    rewriter.replaceOpWithNewOp<CallOp>(op, resType, fn, adaptor.getOperands());
    return success();
  }
};
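// For example, with `index` overhead type and a hypothetical #CSR encoding,
// a pointers access such as
//   %p = sparse_tensor.pointers %t, %c1
//      : tensor<8x8xf64, #CSR> to memref<?xindex>
// becomes, roughly:
//   %p = call @sparsePointers(%t, %c1)
//      : (!llvm.ptr<i8>, index) -> memref<?xindex>
// The indices and values rules below follow the same pattern.
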
/// Sparse conversion rule for index accesses.
class SparseTensorToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToIndicesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type eltType = resType.cast<ShapedType>().getElementType();
    StringRef name;
    if (eltType.isIndex())
      name = "sparseIndices"; // 64-bit, but its own name for unique signature
    else if (eltType.isInteger(64))
      name = "sparseIndices64";
    else if (eltType.isInteger(32))
      name = "sparseIndices32";
    else if (eltType.isInteger(16))
      name = "sparseIndices16";
    else if (eltType.isInteger(8))
      name = "sparseIndices8";
    else
      return failure();
    auto fn = getFunc(op, name, resType, adaptor.getOperands(),
                      /*emitCInterface=*/true);
    rewriter.replaceOpWithNewOp<CallOp>(op, resType, fn, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for value accesses.
class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type eltType = resType.cast<ShapedType>().getElementType();
    StringRef name;
    if (eltType.isF64())
      name = "sparseValuesF64";
    else if (eltType.isF32())
      name = "sparseValuesF32";
    else if (eltType.isInteger(64))
      name = "sparseValuesI64";
    else if (eltType.isInteger(32))
      name = "sparseValuesI32";
    else if (eltType.isInteger(16))
      name = "sparseValuesI16";
    else if (eltType.isInteger(8))
      name = "sparseValuesI8";
    else
      return failure();
    auto fn = getFunc(op, name, resType, adaptor.getOperands(),
                      /*emitCInterface=*/true);
    rewriter.replaceOpWithNewOp<CallOp>(op, resType, fn, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for tensor reconstruction.
class SparseTensorToTensorConverter : public OpConversionPattern<ToTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  // Simply fold the operator into the pointer to the sparse storage scheme.
  LogicalResult
  matchAndRewrite(ToTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Check that all arguments of the tensor reconstruction operator are
    // calls into the support library that query exactly the same opaque
    // pointer.
    Value ptr;
    for (Value op : adaptor.getOperands()) {
      if (auto call = op.getDefiningOp<CallOp>()) {
        Value arg = call.getOperand(0);
        if (!arg.getType().isa<LLVM::LLVMPointerType>())
          return failure();
        if (!ptr)
          ptr = arg;
        else if (arg != ptr)
          return failure();
      }
    }
    // If a single opaque pointer is found, perform the folding.
    if (!ptr)
      return failure();
    rewriter.replaceOp(op, ptr);
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with conversion rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
                                                  RewritePatternSet &patterns) {
  patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter,
               SparseTensorNewConverter, SparseTensorInitConverter,
               SparseTensorConvertConverter, SparseTensorReleaseConverter,
               SparseTensorToPointersConverter, SparseTensorToIndicesConverter,
               SparseTensorToValuesConverter, SparseTensorToTensorConverter>(
      typeConverter, patterns.getContext());
}
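// A conversion pass typically uses these patterns roughly as follows (a
// sketch; the actual pass and its type converter, which maps annotated
// tensor types to the opaque !llvm.ptr<i8>, live in SparseTensorPasses.cpp):
//   RewritePatternSet patterns(ctx);
//   SparseTensorTypeConverter converter;
//   ConversionTarget target(*ctx);
//   populateSparseTensorConversionPatterns(converter, patterns);
//   if (failed(applyPartialConversion(getOperation(), target,
//                                     std::move(patterns))))
//     signalPassFailure();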