//===- SparseTensorConversion.cpp - Sparse tensor primitives conversion ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Convert sparse tensor primitives to calls into a runtime support library.
// Note that this is a current implementation choice to keep the conversion
// simple. In principle, these primitives could also be converted to actual
// elaborate IR code that implements the primitives on the selected sparse
// tensor storage schemes.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/ExecutionEngine/SparseTensorUtils.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

/// Generates a constant zero of the given type.
inline static Value constantZero(ConversionPatternRewriter &rewriter,
                                 Location loc, Type t) {
  return rewriter.create<arith::ConstantOp>(loc, t, rewriter.getZeroAttr(t));
}

/// Generates a constant of `index` type.
inline static Value constantIndex(ConversionPatternRewriter &rewriter,
                                  Location loc, int64_t i) {
  return rewriter.create<arith::ConstantIndexOp>(loc, i);
}

/// Generates a constant of `i32` type.
inline static Value constantI32(ConversionPatternRewriter &rewriter,
                                Location loc, int32_t i) {
  return rewriter.create<arith::ConstantIntOp>(loc, i, 32);
}

/// Generates a constant of `i8` type.
inline static Value constantI8(ConversionPatternRewriter &rewriter,
                               Location loc, int8_t i) {
  return rewriter.create<arith::ConstantIntOp>(loc, i, 8);
}

/// Generates a constant of the given `Action`.
static Value constantAction(ConversionPatternRewriter &rewriter, Location loc,
                            Action action) {
  return constantI32(rewriter, loc, static_cast<uint32_t>(action));
}

/// Generates a constant of the internal type encoding for overhead storage.
static Value constantOverheadTypeEncoding(ConversionPatternRewriter &rewriter,
                                          Location loc, unsigned width) {
  OverheadType sec;
  switch (width) {
  default:
    sec = OverheadType::kU64;
    break;
  case 32:
    sec = OverheadType::kU32;
    break;
  case 16:
    sec = OverheadType::kU16;
    break;
  case 8:
    sec = OverheadType::kU8;
    break;
  }
  return constantI32(rewriter, loc, static_cast<uint32_t>(sec));
}

/// Generates a constant of the internal type encoding for primary storage.
static Value constantPrimaryTypeEncoding(ConversionPatternRewriter &rewriter,
                                         Location loc, Type tp) {
  PrimaryType primary;
  if (tp.isF64())
    primary = PrimaryType::kF64;
  else if (tp.isF32())
    primary = PrimaryType::kF32;
  else if (tp.isInteger(64))
    primary = PrimaryType::kI64;
  else if (tp.isInteger(32))
    primary = PrimaryType::kI32;
  else if (tp.isInteger(16))
    primary = PrimaryType::kI16;
  else if (tp.isInteger(8))
    primary = PrimaryType::kI8;
  else
    llvm_unreachable("Unknown element type");
  return constantI32(rewriter, loc, static_cast<uint32_t>(primary));
}

/// Generates a constant of the internal dimension level type encoding.
static Value
constantDimLevelTypeEncoding(ConversionPatternRewriter &rewriter, Location loc,
                             SparseTensorEncodingAttr::DimLevelType dlt) {
  DimLevelType dlt2;
  switch (dlt) {
  case SparseTensorEncodingAttr::DimLevelType::Dense:
    dlt2 = DimLevelType::kDense;
    break;
  case SparseTensorEncodingAttr::DimLevelType::Compressed:
    dlt2 = DimLevelType::kCompressed;
    break;
  case SparseTensorEncodingAttr::DimLevelType::Singleton:
    dlt2 = DimLevelType::kSingleton;
    break;
  }
  return constantI8(rewriter, loc, static_cast<uint8_t>(dlt2));
}

/// Returns a function reference (first hit also inserts into module). Sets
/// the "llvm.emit_c_interface" attribute on the function declaration when
/// requested, so that LLVM lowering generates a wrapper function that takes
/// care of ABI complications with passing in and returning MemRefs to C
/// functions.
static FlatSymbolRefAttr getFunc(Operation *op, StringRef name,
                                 TypeRange resultType, ValueRange operands,
                                 bool emitCInterface = false) {
  MLIRContext *context = op->getContext();
  auto module = op->getParentOfType<ModuleOp>();
  auto result = SymbolRefAttr::get(context, name);
  auto func = module.lookupSymbol<FuncOp>(result.getAttr());
  if (!func) {
    OpBuilder moduleBuilder(module.getBodyRegion());
    func = moduleBuilder.create<FuncOp>(
        op->getLoc(), name,
        FunctionType::get(context, operands.getTypes(), resultType));
    func.setPrivate();
    if (emitCInterface)
      func->setAttr("llvm.emit_c_interface", UnitAttr::get(context));
  }
  return result;
}

/// Generates dimension size call.
static Value genDimSizeCall(ConversionPatternRewriter &rewriter, Operation *op,
                            SparseTensorEncodingAttr &enc, Value src,
                            int64_t idx) {
  // Permute the index according to an optional dimension ordering.
  if (AffineMap p = enc.getDimOrdering())
    idx = p.getPermutedPosition(idx);
  // Generate the call.
  Location loc = op->getLoc();
  StringRef name = "sparseDimSize";
  SmallVector<Value, 2> params;
  params.push_back(src);
  params.push_back(constantIndex(rewriter, loc, idx));
  Type iTp = rewriter.getIndexType();
  auto fn = getFunc(op, name, iTp, params);
  return rewriter.create<CallOp>(loc, iTp, fn, params).getResult(0);
}
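
// For illustration, a sketch of the IR emitted by genDimSizeCall() when
// querying dimension 0 (value names hypothetical):
//   %c0 = arith.constant 0 : index
//   %sz = call @sparseDimSize(%src, %c0) : (!llvm.ptr<i8>, index) -> index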

/// Generates a call into the "swiss army knife" method of the sparse runtime
/// support library for materializing sparse tensors into the computation.
static Value genNewCall(ConversionPatternRewriter &rewriter, Operation *op,
                        ArrayRef<Value> params) {
  Location loc = op->getLoc();
  StringRef name = "newSparseTensor";
  Type pTp = LLVM::LLVMPointerType::get(rewriter.getI8Type());
  auto fn = getFunc(op, name, pTp, params, /*emitCInterface=*/true);
  auto call = rewriter.create<CallOp>(loc, pTp, fn, params);
  return call.getResult(0);
}

/// Populates given sizes array from type.
static void sizesFromType(ConversionPatternRewriter &rewriter,
                          SmallVector<Value, 4> &sizes, Location loc,
                          ShapedType stp) {
  auto shape = stp.getShape();
  for (unsigned i = 0, rank = stp.getRank(); i < rank; i++) {
    uint64_t s = shape[i] == ShapedType::kDynamicSize ? 0 : shape[i];
    sizes.push_back(constantIndex(rewriter, loc, s));
  }
}

/// Populates given sizes array from source.
static void sizesFromSrc(ConversionPatternRewriter &rewriter,
                         SmallVector<Value, 4> &sizes, Location loc,
                         Value src) {
  ShapedType stp = src.getType().cast<ShapedType>();
  for (unsigned i = 0, rank = stp.getRank(); i < rank; i++)
    sizes.push_back(linalg::createOrFoldDimOp(rewriter, loc, src, i));
}

/// Populates given sizes array from type (for static sizes) and from
/// the source, already converted into an opaque pointer (for dynamic sizes).
static void sizesFromPtr(ConversionPatternRewriter &rewriter,
                         SmallVector<Value, 4> &sizes, Operation *op,
                         SparseTensorEncodingAttr &enc, ShapedType stp,
                         Value src) {
  auto shape = stp.getShape();
  for (unsigned i = 0, rank = stp.getRank(); i < rank; i++)
    if (shape[i] == ShapedType::kDynamicSize)
      sizes.push_back(genDimSizeCall(rewriter, op, enc, src, i));
    else
      sizes.push_back(constantIndex(rewriter, op->getLoc(), shape[i]));
}

/// Generates an uninitialized temporary buffer of the given size and
/// type, but returns it as type `memref<? x $tp>` (rather than as type
/// `memref<$sz x $tp>`).
static Value genAlloca(ConversionPatternRewriter &rewriter, Location loc,
                       unsigned sz, Type tp) {
  auto memTp = MemRefType::get({ShapedType::kDynamicSize}, tp);
  Value a = constantIndex(rewriter, loc, sz);
  return rewriter.create<memref::AllocaOp>(loc, memTp, ValueRange{a});
}

/// Generates an uninitialized temporary buffer with room for one value
/// of the given type, and returns the `memref<$tp>`.
static Value genAllocaScalar(ConversionPatternRewriter &rewriter, Location loc,
                             Type tp) {
  return rewriter.create<memref::AllocaOp>(loc, MemRefType::get({}, tp));
}

/// Generates a temporary buffer of the given type and given contents.
static Value genBuffer(ConversionPatternRewriter &rewriter, Location loc,
                       ArrayRef<Value> values) {
  unsigned sz = values.size();
  assert(sz >= 1);
  Value buffer = genAlloca(rewriter, loc, sz, values[0].getType());
  for (unsigned i = 0; i < sz; i++) {
    Value idx = constantIndex(rewriter, loc, i);
    rewriter.create<memref::StoreOp>(loc, values[i], buffer, idx);
  }
  return buffer;
}
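
// For illustration, a sketch of the IR emitted by genBuffer() for two index
// values %v0 and %v1 (value names hypothetical):
//   %a = memref.alloca(%c2) : memref<?xindex>
//   memref.store %v0, %a[%c0] : memref<?xindex>
//   memref.store %v1, %a[%c1] : memref<?xindex>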

/// Populates parameters required to call the "swiss army knife" method of the
/// sparse runtime support library for materializing sparse tensors into the
/// computation.
static void newParams(ConversionPatternRewriter &rewriter,
                      SmallVector<Value, 8> &params, Operation *op,
                      SparseTensorEncodingAttr &enc, Action action,
                      ValueRange szs, Value ptr = Value()) {
  Location loc = op->getLoc();
  ArrayRef<SparseTensorEncodingAttr::DimLevelType> dlt = enc.getDimLevelType();
  unsigned sz = dlt.size();
  // Sparsity annotations.
  SmallVector<Value, 4> attrs;
  for (unsigned i = 0; i < sz; i++)
    attrs.push_back(constantDimLevelTypeEncoding(rewriter, loc, dlt[i]));
  params.push_back(genBuffer(rewriter, loc, attrs));
  // Dimension sizes array of the enveloping tensor. Useful for either
  // verification of external data, or for construction of internal data.
  SmallVector<Value, 4> sizes;
  for (Value s : szs)
    sizes.push_back(s);
  params.push_back(genBuffer(rewriter, loc, sizes));
  // Dimension order permutation array. This is the "identity" permutation by
  // default, or otherwise the "reverse" permutation of a given ordering, so
  // that indices can be mapped quickly to the right position.
  SmallVector<Value, 4> rev(sz);
  if (AffineMap p = enc.getDimOrdering()) {
    for (unsigned i = 0; i < sz; i++)
      rev[p.getDimPosition(i)] = constantIndex(rewriter, loc, i);
  } else {
    for (unsigned i = 0; i < sz; i++)
      rev[i] = constantIndex(rewriter, loc, i);
  }
  params.push_back(genBuffer(rewriter, loc, rev));
  // Secondary and primary types encoding.
  ShapedType resType = op->getResult(0).getType().cast<ShapedType>();
  params.push_back(
      constantOverheadTypeEncoding(rewriter, loc, enc.getPointerBitWidth()));
  params.push_back(
      constantOverheadTypeEncoding(rewriter, loc, enc.getIndexBitWidth()));
  params.push_back(
      constantPrimaryTypeEncoding(rewriter, loc, resType.getElementType()));
  // User action and pointer.
  Type pTp = LLVM::LLVMPointerType::get(rewriter.getI8Type());
  if (!ptr)
    ptr = rewriter.create<LLVM::NullOp>(loc, pTp);
  params.push_back(constantAction(rewriter, loc, action));
  params.push_back(ptr);
}
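
// For reference, newParams() always fills the array in this fixed order; the
// conversion rules below rely on these positions when overwriting entries
// (e.g., params[6] and params[7]) between consecutive genNewCall() calls:
//   params[0]  dimension level types buffer
//   params[1]  dimension sizes buffer
//   params[2]  dimension order permutation buffer
//   params[3]  pointer overhead type encoding
//   params[4]  index overhead type encoding
//   params[5]  primary type encoding
//   params[6]  action
//   params[7]  opaque pointer (null when not used by the action)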

/// Generates the comparison `v != 0` where `v` is of numeric type `t`.
/// For floating types, we use the "unordered" comparator (i.e., returns
/// true if `v` is NaN).
static Value genIsNonzero(ConversionPatternRewriter &rewriter, Location loc,
                          Value v) {
  Type t = v.getType();
  Value zero = constantZero(rewriter, loc, t);
  if (t.isa<FloatType>())
    return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE, v,
                                          zero);
  if (t.isIntOrIndex())
    return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, v,
                                          zero);
  llvm_unreachable("Unknown element type");
}

/// Generates the code to read the value from tensor[ivs], and conditionally
/// stores the indices ivs to the memory in ind. The generated code looks like
/// the following and the insertion point after this routine is inside the
/// if-then branch behind the assignment to ind. This is to ensure that the
/// addEltX call generated after is inside the if-then branch.
///    if (tensor[ivs] != 0) {
///      ind = ivs
static Value genIndexAndValueForDense(ConversionPatternRewriter &rewriter,
                                      Location loc, Value tensor, Value ind,
                                      ValueRange ivs) {
  Value val = rewriter.create<tensor::ExtractOp>(loc, tensor, ivs);
  Value cond = genIsNonzero(rewriter, loc, val);
  scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ false);
  rewriter.setInsertionPointToStart(&ifOp.thenRegion().front());
  unsigned i = 0;
  for (auto iv : ivs) {
    Value idx = constantIndex(rewriter, loc, i++);
    rewriter.create<memref::StoreOp>(loc, iv, ind, idx);
  }
  return val;
}

/// Generates a call that adds one element to a coordinate scheme.
/// In particular, this generates code like the following:
///   val = a[i1,..,ik];
///   if val != 0
///     t->add(val, [i1,..,ik], [p1,..,pk]);
static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
                          Type eltType, Value ptr, Value val, Value ind,
                          Value perm) {
  Location loc = op->getLoc();
  StringRef name;
  if (eltType.isF64())
    name = "addEltF64";
  else if (eltType.isF32())
    name = "addEltF32";
  else if (eltType.isInteger(64))
    name = "addEltI64";
  else if (eltType.isInteger(32))
    name = "addEltI32";
  else if (eltType.isInteger(16))
    name = "addEltI16";
  else if (eltType.isInteger(8))
    name = "addEltI8";
  else
    llvm_unreachable("Unknown element type");
  SmallVector<Value, 8> params;
  params.push_back(ptr);
  params.push_back(val);
  params.push_back(ind);
  params.push_back(perm);
  Type pTp = LLVM::LLVMPointerType::get(rewriter.getI8Type());
  auto fn = getFunc(op, name, pTp, params, /*emitCInterface=*/true);
  rewriter.create<CallOp>(loc, pTp, fn, params);
}

/// Generates a call to `iter->getNext()`. If there is a next element,
/// then it is copied into the out-parameters `ind` and `elemPtr`,
/// and the return value is true. If there isn't a next element, then
/// the memory for `iter` is freed and the return value is false.
static Value genGetNextCall(ConversionPatternRewriter &rewriter, Operation *op,
                            Value iter, Value ind, Value elemPtr) {
  Location loc = op->getLoc();
  Type elemTp = elemPtr.getType().cast<ShapedType>().getElementType();
  StringRef name;
  if (elemTp.isF64())
    name = "getNextF64";
  else if (elemTp.isF32())
    name = "getNextF32";
  else if (elemTp.isInteger(64))
    name = "getNextI64";
  else if (elemTp.isInteger(32))
    name = "getNextI32";
  else if (elemTp.isInteger(16))
    name = "getNextI16";
  else if (elemTp.isInteger(8))
    name = "getNextI8";
  else
    llvm_unreachable("Unknown element type");
  SmallVector<Value, 3> params;
  params.push_back(iter);
  params.push_back(ind);
  params.push_back(elemPtr);
  Type i1 = rewriter.getI1Type();
  auto fn = getFunc(op, name, i1, params, /*emitCInterface=*/true);
  auto call = rewriter.create<CallOp>(loc, i1, fn, params);
  return call.getResult(0);
}
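
// For illustration, a sketch of the while-loop that the sparse => dense
// conversion below builds around genGetNextCall() (f64 case, value names
// hypothetical):
//   scf.while : () -> () {
//     %more = call @getNextF64(%iter, %ind, %elemPtr) : (...) -> i1
//     scf.condition(%more)
//   } do {
//     ... load *%elemPtr and store it into the dense buffer at %ind ...
//     scf.yield
//   }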

/// If the tensor is a sparse constant, generates and returns the pair of
/// the constants for the indices and the values.
static Optional<std::pair<Value, Value>>
genSplitSparseConstant(ConversionPatternRewriter &rewriter, Location loc,
                       Value tensor) {
  if (auto constOp = tensor.getDefiningOp<arith::ConstantOp>()) {
    if (auto attr = constOp.getValue().dyn_cast<SparseElementsAttr>()) {
      DenseElementsAttr indicesAttr = attr.getIndices();
      Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr);
      DenseElementsAttr valuesAttr = attr.getValues();
      Value values = rewriter.create<arith::ConstantOp>(loc, valuesAttr);
      return std::make_pair(indices, values);
    }
  }
  return {};
}

/// Generates the code to copy the index at indices[ivs] to ind, and return
/// the value at values[ivs].
static Value genIndexAndValueForSparse(ConversionPatternRewriter &rewriter,
                                       Location loc, Value indices,
                                       Value values, Value ind, ValueRange ivs,
                                       unsigned rank) {
  for (unsigned i = 0; i < rank; i++) {
    Value idx = constantIndex(rewriter, loc, i);
    Value val = rewriter.create<tensor::ExtractOp>(loc, indices,
                                                   ValueRange{ivs[0], idx});
    val =
        rewriter.create<arith::IndexCastOp>(loc, val, rewriter.getIndexType());
    rewriter.create<memref::StoreOp>(loc, val, ind, idx);
  }
  return rewriter.create<tensor::ExtractOp>(loc, values, ivs[0]);
}

/// Generates code to allocate a tensor of the given type, and zero
/// initialize it. If the tensor type has any dynamic sizes, then the
/// `sizes` parameter should be as filled by sizesFromPtr(); that way
/// we can reuse the genDimSizeCall() results generated by sizesFromPtr().
static Value allocDenseTensor(ConversionPatternRewriter &rewriter,
                              Location loc, RankedTensorType tensorTp,
                              ValueRange sizes) {
  Type elemTp = tensorTp.getElementType();
  auto shape = tensorTp.getShape();
  auto memTp = MemRefType::get(shape, elemTp);
  SmallVector<Value> dynamicSizes;
  for (unsigned i = 0, rank = tensorTp.getRank(); i < rank; i++) {
    if (shape[i] == ShapedType::kDynamicSize)
      dynamicSizes.push_back(sizes[i]);
  }
  Value mem = rewriter.create<memref::AllocOp>(loc, memTp, dynamicSizes);
  Value zero = constantZero(rewriter, loc, elemTp);
  rewriter.create<linalg::FillOp>(loc, zero, mem).result();
  return mem;
}
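
// For illustration, a sketch of the IR emitted by allocDenseTensor() for a
// tensor<?x4xf64> with one dynamic size %d0 (value names hypothetical):
//   %mem  = memref.alloc(%d0) : memref<?x4xf64>
//   %zero = arith.constant 0.0 : f64
//   linalg.fill(%zero, %mem) : f64, memref<?x4xf64>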

/// Inserts the element returned by genGetNextCall(_, ind, elemPtr) into
/// the tensor created by allocDenseTensor(). The `rank` is the rank
/// of the `tensor` and the length of `ind`.
static void insertScalarIntoDenseTensor(ConversionPatternRewriter &rewriter,
                                        Location loc, Value elemPtr,
                                        Value tensor, unsigned rank,
                                        Value ind) {
  SmallVector<Value, 4> ivs;
  ivs.reserve(rank);
  for (unsigned i = 0; i < rank; i++) {
    Value idx = constantIndex(rewriter, loc, i);
    ivs.push_back(rewriter.create<memref::LoadOp>(loc, ind, idx));
  }
  Value elemV = rewriter.create<memref::LoadOp>(loc, elemPtr);
  rewriter.create<memref::StoreOp>(loc, elemV, tensor, ivs);
}

//===----------------------------------------------------------------------===//
// Conversion rules.
//===----------------------------------------------------------------------===//

/// Sparse conversion rule for returns.
class SparseReturnConverter : public OpConversionPattern<ReturnOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<ReturnOp>(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for dimension accesses.
class SparseTensorToDimSizeConverter
    : public OpConversionPattern<tensor::DimOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only rewrite annotated DimOp with constant index.
    auto enc = getSparseTensorEncoding(op.source().getType());
    if (!enc)
      return failure();
    Optional<int64_t> index = op.getConstantIndex();
    if (!index.hasValue())
      return failure();
    // Generate the call.
    Value src = adaptor.getOperands()[0];
    int64_t idx = index.getValue();
    rewriter.replaceOp(op, genDimSizeCall(rewriter, op, enc, src, idx));
    return success();
  }
};

/// Sparse conversion rule for trivial tensor casts.
class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only rewrite identically annotated source/dest.
    auto encDst = getSparseTensorEncoding(op.getType());
    auto encSrc = getSparseTensorEncoding(op.source().getType());
    if (!encDst || encDst != encSrc)
      return failure();
    rewriter.replaceOp(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for the new operator.
class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NewOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    auto enc = getSparseTensorEncoding(resType);
    if (!enc)
      return failure();
    // Generate the call to construct tensor from ptr. The sizes are
    // inferred from the result type of the new operator.
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 8> params;
    sizesFromType(rewriter, sizes, op.getLoc(), resType.cast<ShapedType>());
    Value ptr = adaptor.getOperands()[0];
    newParams(rewriter, params, op, enc, Action::kFromFile, sizes, ptr);
    rewriter.replaceOp(op, genNewCall(rewriter, op, params));
    return success();
  }
};
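
// For illustration, a sketch of the call that the new operator conversion
// above emits (the init conversion below is analogous; buffer and constant
// names hypothetical, see newParams() for the parameter layout):
//   %t = call @newSparseTensor(%attrs, %sizes, %perm, %ptrTp, %indTp, %valTp,
//                              %action, %ptr)
//        : (memref<?xi8>, memref<?xindex>, memref<?xindex>, i32, i32, i32,
//           i32, !llvm.ptr<i8>) -> !llvm.ptr<i8>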

/// Sparse conversion rule for the init operator.
class SparseTensorInitConverter : public OpConversionPattern<InitOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(InitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    auto enc = getSparseTensorEncoding(resType);
    if (!enc)
      return failure();
    // Generate the call to construct empty tensor. The sizes are
    // explicitly defined by the arguments to the init operator.
    SmallVector<Value, 8> params;
    newParams(rewriter, params, op, enc, Action::kEmpty, adaptor.getOperands());
    rewriter.replaceOp(op, genNewCall(rewriter, op, params));
    return success();
  }
};

/// Sparse conversion rule for the convert operator.
class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ConvertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    Type resType = op.getType();
    Type srcType = op.source().getType();
    auto encDst = getSparseTensorEncoding(resType);
    auto encSrc = getSparseTensorEncoding(srcType);
    Value src = adaptor.getOperands()[0];
    if (encDst && encSrc) {
      // This is a sparse => sparse conversion, which is handled as follows:
      //   t = src->toCOO();         ; src to COO in dst order
      //   dst = newSparseTensor(t)
      // Using the coordinate scheme as an intermediate does not always
      // yield the fastest conversion but avoids the need for a full
      // O(N^2) conversion matrix.
      if (encDst == encSrc) {
        rewriter.replaceOp(op, adaptor.getOperands()); // hidden nop cast
        return success();
      }
      SmallVector<Value, 4> sizes;
      SmallVector<Value, 8> params;
      sizesFromPtr(rewriter, sizes, op, encSrc, srcType.cast<ShapedType>(),
                   src);
      // Set up encoding with right mix of src and dst so that the two
      // method calls can share most parameters, while still providing
      // the correct sparsity information to either of them.
      auto enc = SparseTensorEncodingAttr::get(
          op->getContext(), encDst.getDimLevelType(), encDst.getDimOrdering(),
          encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
      newParams(rewriter, params, op, enc, Action::kToCOO, sizes, src);
      Value coo = genNewCall(rewriter, op, params);
      params[3] = constantOverheadTypeEncoding(rewriter, loc,
                                               encDst.getPointerBitWidth());
      params[4] = constantOverheadTypeEncoding(rewriter, loc,
                                               encDst.getIndexBitWidth());
      params[6] = constantAction(rewriter, loc, Action::kFromCOO);
      params[7] = coo;
      rewriter.replaceOp(op, genNewCall(rewriter, op, params));
      return success();
    }
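
    // For illustration, the sparse => sparse case above amounts to two calls
    // into the support library (value names hypothetical):
    //   %coo = call @newSparseTensor(..., %actionToCOO, %src)
    //   %dst = call @newSparseTensor(..., %actionFromCOO, %coo)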

    if (!encDst && encSrc) {
      // This is sparse => dense conversion, which is handled as follows:
      //   dst = new Tensor(0);
      //   iter = src->toCOO();
      //   iter->startIterator();
      //   while (elem = iter->getNext()) {
      //     dst[elem.indices] = elem.value;
      //   }
      RankedTensorType dstTensorTp = resType.cast<RankedTensorType>();
      RankedTensorType srcTensorTp = srcType.cast<RankedTensorType>();
      unsigned rank = dstTensorTp.getRank();
      Type elemTp = dstTensorTp.getElementType();
      // Fabricate a no-permutation encoding for newParams().
      // The pointer/index types must be those of `src`.
      // The dimLevelTypes aren't actually used by Action::kToIterator.
      encDst = SparseTensorEncodingAttr::get(
          op->getContext(),
          SmallVector<SparseTensorEncodingAttr::DimLevelType>(
              rank, SparseTensorEncodingAttr::DimLevelType::Dense),
          AffineMap(), encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
      SmallVector<Value, 4> sizes;
      SmallVector<Value, 8> params;
      sizesFromPtr(rewriter, sizes, op, encSrc, srcTensorTp, src);
      newParams(rewriter, params, op, encDst, Action::kToIterator, sizes, src);
      Value iter = genNewCall(rewriter, op, params);
      Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
      Value elemPtr = genAllocaScalar(rewriter, loc, elemTp);
      Value dst = allocDenseTensor(rewriter, loc, dstTensorTp, sizes);
      SmallVector<Value> noArgs;
      SmallVector<Type> noTypes;
      auto whileOp = rewriter.create<scf::WhileOp>(loc, noTypes, noArgs);
      Block *before = rewriter.createBlock(&whileOp.before(), {}, noTypes);
      rewriter.setInsertionPointToEnd(before);
      Value cond = genGetNextCall(rewriter, op, iter, ind, elemPtr);
      rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments());
      Block *after = rewriter.createBlock(&whileOp.after(), {}, noTypes);
      rewriter.setInsertionPointToStart(after);
      insertScalarIntoDenseTensor(rewriter, loc, elemPtr, dst, rank, ind);
      rewriter.create<scf::YieldOp>(loc);
      rewriter.setInsertionPointAfter(whileOp);
      rewriter.replaceOpWithNewOp<memref::TensorLoadOp>(op, resType, dst);
      return success();
    }
    if (!encDst && !encSrc) {
      // dense => dense
      return failure();
    }

    // This is a dense => sparse conversion or a sparse constant in COO =>
    // sparse conversion, which is handled as follows:
    //   t = newSparseCOO()
    //   ...code to fill the COO tensor t...
    //   s = newSparseTensor(t)
    //
    // To fill the COO tensor from a dense tensor:
    //   for i1 in dim1
    //    ..
    //     for ik in dimk
    //       val = a[i1,..,ik]
    //       if val != 0
    //         t->add(val, [i1,..,ik], [p1,..,pk])
    //
    // To fill the COO tensor from a sparse constant in COO format:
    //   for i in range(NNZ)
    //     val = values[i]
    //     [i1,..,ik] = indices[i]
    //     t->add(val, [i1,..,ik], [p1,..,pk])
    //
    // Note that the dense tensor traversal code is actually implemented
    // using MLIR IR to avoid having to expose too much low-level
    // memref traversal details to the runtime support library.
    // Also note that the code below only generates the "new" ops and
    // the loop-nest per se; whereas the entire body of the innermost
    // loop is generated by genAddEltCall().
    ShapedType stp = resType.cast<ShapedType>();
    unsigned rank = stp.getRank();
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 8> params;
    sizesFromSrc(rewriter, sizes, loc, src);
    newParams(rewriter, params, op, encDst, Action::kEmptyCOO, sizes);
    Value ptr = genNewCall(rewriter, op, params);
    Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
    Value perm = params[2];
    SmallVector<Value> lo;
    SmallVector<Value> hi;
    SmallVector<Value> st;
    Value zero = constantIndex(rewriter, loc, 0);
    Value one = constantIndex(rewriter, loc, 1);
    auto indicesValues = genSplitSparseConstant(rewriter, loc, src);
    bool isCOOConstant = indicesValues.hasValue();
    Value indices;
    Value values;
    if (isCOOConstant) {
      indices = indicesValues->first;
      values = indicesValues->second;
      lo.push_back(zero);
      hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, values, 0));
      st.push_back(one);
    } else {
      for (unsigned i = 0; i < rank; i++) {
        lo.push_back(zero);
        hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, src, i));
        st.push_back(one);
      }
    }
    Type eltType = stp.getElementType();
    scf::buildLoopNest(
        rewriter, op.getLoc(), lo, hi, st, {},
        [&](OpBuilder &builder, Location loc, ValueRange ivs,
            ValueRange args) -> scf::ValueVector {
          Value val;
          if (isCOOConstant)
            val = genIndexAndValueForSparse(rewriter, loc, indices, values,
                                            ind, ivs, rank);
          else
            val = genIndexAndValueForDense(rewriter, loc, src, ind, ivs);
          genAddEltCall(rewriter, op, eltType, ptr, val, ind, perm);
          return {};
        });
    // Final call to construct sparse tensor storage.
    params[6] = constantAction(rewriter, loc, Action::kFromCOO);
    params[7] = ptr;
    rewriter.replaceOp(op, genNewCall(rewriter, op, params));
    return success();
  }
};

/// Sparse conversion rule for the release operator.
class SparseTensorReleaseConverter : public OpConversionPattern<ReleaseOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReleaseOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    StringRef name = "delSparseTensor";
    TypeRange none;
    auto fn = getFunc(op, name, none, adaptor.getOperands());
    rewriter.create<CallOp>(op.getLoc(), none, fn, adaptor.getOperands());
    rewriter.eraseOp(op);
    return success();
  }
};
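
// For illustration, the release conversion above emits a single call into the
// support library (value name hypothetical):
//   call @delSparseTensor(%tensor) : (!llvm.ptr<i8>) -> ()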

/// Sparse conversion rule for pointer accesses.
class SparseTensorToPointersConverter
    : public OpConversionPattern<ToPointersOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToPointersOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type eltType = resType.cast<ShapedType>().getElementType();
    StringRef name;
    if (eltType.isIndex())
      name = "sparsePointers";
    else if (eltType.isInteger(64))
      name = "sparsePointers64";
    else if (eltType.isInteger(32))
      name = "sparsePointers32";
    else if (eltType.isInteger(16))
      name = "sparsePointers16";
    else if (eltType.isInteger(8))
      name = "sparsePointers8";
    else
      return failure();
    auto fn = getFunc(op, name, resType, adaptor.getOperands(),
                      /*emitCInterface=*/true);
    rewriter.replaceOpWithNewOp<CallOp>(op, resType, fn, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for index accesses.
class SparseTensorToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToIndicesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type eltType = resType.cast<ShapedType>().getElementType();
    StringRef name;
    if (eltType.isIndex())
      name = "sparseIndices";
    else if (eltType.isInteger(64))
      name = "sparseIndices64";
    else if (eltType.isInteger(32))
      name = "sparseIndices32";
    else if (eltType.isInteger(16))
      name = "sparseIndices16";
    else if (eltType.isInteger(8))
      name = "sparseIndices8";
    else
      return failure();
    auto fn = getFunc(op, name, resType, adaptor.getOperands(),
                      /*emitCInterface=*/true);
    rewriter.replaceOpWithNewOp<CallOp>(op, resType, fn, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for value accesses.
class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type eltType = resType.cast<ShapedType>().getElementType();
    StringRef name;
    if (eltType.isF64())
      name = "sparseValuesF64";
    else if (eltType.isF32())
      name = "sparseValuesF32";
    else if (eltType.isInteger(64))
      name = "sparseValuesI64";
    else if (eltType.isInteger(32))
      name = "sparseValuesI32";
    else if (eltType.isInteger(16))
      name = "sparseValuesI16";
    else if (eltType.isInteger(8))
      name = "sparseValuesI8";
    else
      return failure();
    auto fn = getFunc(op, name, resType, adaptor.getOperands(),
                      /*emitCInterface=*/true);
    rewriter.replaceOpWithNewOp<CallOp>(op, resType, fn, adaptor.getOperands());
    return success();
  }
};
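
// For illustration, a sketch of the lowering for `sparse_tensor.pointers`
// with i64 overhead storage; the indices and values accesses above lower
// analogously (value names hypothetical):
//   %ptrs = call @sparsePointers64(%tensor, %dim)
//           : (!llvm.ptr<i8>, index) -> memref<?xi64>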

/// Sparse conversion rule for tensor reconstruction.
class SparseTensorToTensorConverter : public OpConversionPattern<ToTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  // Simply fold the operator into the pointer to the sparse storage scheme.
  LogicalResult
  matchAndRewrite(ToTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Check that all arguments of the tensor reconstruction operator are
    // calls into the support library that query exactly the same opaque
    // pointer.
    Value ptr;
    for (Value op : adaptor.getOperands()) {
      if (auto call = op.getDefiningOp<CallOp>()) {
        Value arg = call.getOperand(0);
        if (!arg.getType().isa<LLVM::LLVMPointerType>())
          return failure();
        if (!ptr)
          ptr = arg;
        else if (arg != ptr)
          return failure();
      }
    }
    // If a single opaque pointer is found, perform the folding.
    if (!ptr)
      return failure();
    rewriter.replaceOp(op, ptr);
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with conversion rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
                                                  RewritePatternSet &patterns) {
  patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter,
               SparseCastConverter, SparseTensorNewConverter,
               SparseTensorInitConverter, SparseTensorConvertConverter,
               SparseTensorReleaseConverter, SparseTensorToPointersConverter,
               SparseTensorToIndicesConverter, SparseTensorToValuesConverter,
               SparseTensorToTensorConverter>(typeConverter,
                                              patterns.getContext());
}
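
// For illustration, a driver pass could wire these rules up roughly as
// follows (the type converter, conversion target, and pass context are
// hypothetical here, not part of this file):
//   RewritePatternSet patterns(ctx);
//   populateSparseTensorConversionPatterns(typeConverter, patterns);
//   if (failed(applyPartialConversion(getOperation(), target,
//                                     std::move(patterns))))
//     signalPassFailure();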