1 //===- SparseTensorConversion.cpp - Sparse tensor primitives conversion ---===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Convert sparse tensor primitives to calls into a runtime support library. 10 // Note that this is a current implementation choice to keep the conversion 11 // simple. In principle, these primitives could also be converted to actual 12 // elaborate IR code that implements the primitives on the selected sparse 13 // tensor storage schemes. 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 18 #include "mlir/Dialect/Linalg/Utils/Utils.h" 19 #include "mlir/Dialect/MemRef/IR/MemRef.h" 20 #include "mlir/Dialect/SCF/SCF.h" 21 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 22 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" 23 #include "mlir/Dialect/StandardOps/IR/Ops.h" 24 #include "mlir/Dialect/Tensor/IR/Tensor.h" 25 #include "mlir/ExecutionEngine/SparseTensorUtils.h" 26 #include "mlir/Transforms/DialectConversion.h" 27 28 using namespace mlir; 29 using namespace mlir::sparse_tensor; 30 31 namespace { 32 33 //===----------------------------------------------------------------------===// 34 // Helper methods. 35 //===----------------------------------------------------------------------===// 36 37 /// Generates a constant zero of the given type. 38 inline static Value constantZero(ConversionPatternRewriter &rewriter, 39 Location loc, Type t) { 40 return rewriter.create<arith::ConstantOp>(loc, t, rewriter.getZeroAttr(t)); 41 } 42 43 /// Generates a constant of `index` type. 44 inline static Value constantIndex(ConversionPatternRewriter &rewriter, 45 Location loc, int64_t i) { 46 return rewriter.create<arith::ConstantIndexOp>(loc, i); 47 } 48 49 /// Generates a constant of `i32` type. 50 inline static Value constantI32(ConversionPatternRewriter &rewriter, 51 Location loc, int32_t i) { 52 return rewriter.create<arith::ConstantIntOp>(loc, i, 32); 53 } 54 55 /// Generates a constant of `i8` type. 56 inline static Value constantI8(ConversionPatternRewriter &rewriter, 57 Location loc, int8_t i) { 58 return rewriter.create<arith::ConstantIntOp>(loc, i, 8); 59 } 60 61 /// Generates a constant of the given `Action`. 62 static Value constantAction(ConversionPatternRewriter &rewriter, Location loc, 63 Action action) { 64 return constantI32(rewriter, loc, static_cast<uint32_t>(action)); 65 } 66 67 /// Generates a constant of the internal type encoding for overhead storage. 68 static Value constantOverheadTypeEncoding(ConversionPatternRewriter &rewriter, 69 Location loc, unsigned width) { 70 OverheadType sec; 71 switch (width) { 72 default: 73 sec = OverheadType::kU64; 74 break; 75 case 32: 76 sec = OverheadType::kU32; 77 break; 78 case 16: 79 sec = OverheadType::kU16; 80 break; 81 case 8: 82 sec = OverheadType::kU8; 83 break; 84 } 85 return constantI32(rewriter, loc, static_cast<uint32_t>(sec)); 86 } 87 88 /// Generates a constant of the internal type encoding for pointer 89 /// overhead storage. 90 static Value constantPointerTypeEncoding(ConversionPatternRewriter &rewriter, 91 Location loc, 92 SparseTensorEncodingAttr &enc) { 93 return constantOverheadTypeEncoding(rewriter, loc, enc.getPointerBitWidth()); 94 } 95 96 /// Generates a constant of the internal type encoding for index overhead 97 /// storage. 98 static Value constantIndexTypeEncoding(ConversionPatternRewriter &rewriter, 99 Location loc, 100 SparseTensorEncodingAttr &enc) { 101 return constantOverheadTypeEncoding(rewriter, loc, enc.getIndexBitWidth()); 102 } 103 104 /// Generates a constant of the internal type encoding for primary storage. 105 static Value constantPrimaryTypeEncoding(ConversionPatternRewriter &rewriter, 106 Location loc, Type tp) { 107 PrimaryType primary; 108 if (tp.isF64()) 109 primary = PrimaryType::kF64; 110 else if (tp.isF32()) 111 primary = PrimaryType::kF32; 112 else if (tp.isInteger(64)) 113 primary = PrimaryType::kI64; 114 else if (tp.isInteger(32)) 115 primary = PrimaryType::kI32; 116 else if (tp.isInteger(16)) 117 primary = PrimaryType::kI16; 118 else if (tp.isInteger(8)) 119 primary = PrimaryType::kI8; 120 else 121 llvm_unreachable("Unknown element type"); 122 return constantI32(rewriter, loc, static_cast<uint32_t>(primary)); 123 } 124 125 /// Generates a constant of the internal dimension level type encoding. 126 static Value 127 constantDimLevelTypeEncoding(ConversionPatternRewriter &rewriter, Location loc, 128 SparseTensorEncodingAttr::DimLevelType dlt) { 129 DimLevelType dlt2; 130 switch (dlt) { 131 case SparseTensorEncodingAttr::DimLevelType::Dense: 132 dlt2 = DimLevelType::kDense; 133 break; 134 case SparseTensorEncodingAttr::DimLevelType::Compressed: 135 dlt2 = DimLevelType::kCompressed; 136 break; 137 case SparseTensorEncodingAttr::DimLevelType::Singleton: 138 dlt2 = DimLevelType::kSingleton; 139 break; 140 } 141 return constantI8(rewriter, loc, static_cast<uint8_t>(dlt2)); 142 } 143 144 /// Returns a function reference (first hit also inserts into module). Sets 145 /// the "_emit_c_interface" on the function declaration when requested, 146 /// so that LLVM lowering generates a wrapper function that takes care 147 /// of ABI complications with passing in and returning MemRefs to C functions. 148 static FlatSymbolRefAttr getFunc(Operation *op, StringRef name, 149 TypeRange resultType, ValueRange operands, 150 bool emitCInterface = false) { 151 MLIRContext *context = op->getContext(); 152 auto module = op->getParentOfType<ModuleOp>(); 153 auto result = SymbolRefAttr::get(context, name); 154 auto func = module.lookupSymbol<FuncOp>(result.getAttr()); 155 if (!func) { 156 OpBuilder moduleBuilder(module.getBodyRegion()); 157 func = moduleBuilder.create<FuncOp>( 158 op->getLoc(), name, 159 FunctionType::get(context, operands.getTypes(), resultType)); 160 func.setPrivate(); 161 if (emitCInterface) 162 func->setAttr("llvm.emit_c_interface", UnitAttr::get(context)); 163 } 164 return result; 165 } 166 167 /// Generates dimension size call. 168 static Value genDimSizeCall(ConversionPatternRewriter &rewriter, Operation *op, 169 SparseTensorEncodingAttr &enc, Value src, 170 int64_t idx) { 171 // Permute the index according to an optional dimension ordering. 172 if (AffineMap p = enc.getDimOrdering()) 173 idx = p.getPermutedPosition(idx); 174 // Generate the call. 175 Location loc = op->getLoc(); 176 StringRef name = "sparseDimSize"; 177 SmallVector<Value, 2> params; 178 params.push_back(src); 179 params.push_back(constantIndex(rewriter, loc, idx)); 180 Type iTp = rewriter.getIndexType(); 181 auto fn = getFunc(op, name, iTp, params); 182 return rewriter.create<CallOp>(loc, iTp, fn, params).getResult(0); 183 } 184 185 /// Generates a call into the "swiss army knife" method of the sparse runtime 186 /// support library for materializing sparse tensors into the computation. 187 static Value genNewCall(ConversionPatternRewriter &rewriter, Operation *op, 188 ArrayRef<Value> params) { 189 Location loc = op->getLoc(); 190 StringRef name = "newSparseTensor"; 191 Type pTp = LLVM::LLVMPointerType::get(rewriter.getI8Type()); 192 auto fn = getFunc(op, name, pTp, params, /*emitCInterface=*/true); 193 auto call = rewriter.create<CallOp>(loc, pTp, fn, params); 194 return call.getResult(0); 195 } 196 197 /// Populates given sizes array from type. 198 static void sizesFromType(ConversionPatternRewriter &rewriter, 199 SmallVector<Value, 4> &sizes, Location loc, 200 ShapedType stp) { 201 auto shape = stp.getShape(); 202 for (unsigned i = 0, rank = stp.getRank(); i < rank; i++) { 203 uint64_t s = shape[i] == ShapedType::kDynamicSize ? 0 : shape[i]; 204 sizes.push_back(constantIndex(rewriter, loc, s)); 205 } 206 } 207 208 /// Populates given sizes array from source. 209 static void sizesFromSrc(ConversionPatternRewriter &rewriter, 210 SmallVector<Value, 4> &sizes, Location loc, 211 Value src) { 212 ShapedType stp = src.getType().cast<ShapedType>(); 213 for (unsigned i = 0, rank = stp.getRank(); i < rank; i++) 214 sizes.push_back(linalg::createOrFoldDimOp(rewriter, loc, src, i)); 215 } 216 217 /// Populates given sizes array from type (for static sizes) and from 218 /// an already converted into opague pointer source (for dynamic sizes). 219 static void sizesFromPtr(ConversionPatternRewriter &rewriter, 220 SmallVector<Value, 4> &sizes, Operation *op, 221 SparseTensorEncodingAttr &enc, ShapedType stp, 222 Value src) { 223 auto shape = stp.getShape(); 224 for (unsigned i = 0, rank = stp.getRank(); i < rank; i++) 225 if (shape[i] == ShapedType::kDynamicSize) 226 sizes.push_back(genDimSizeCall(rewriter, op, enc, src, i)); 227 else 228 sizes.push_back(constantIndex(rewriter, op->getLoc(), shape[i])); 229 } 230 231 /// Generates an uninitialized temporary buffer of the given size and 232 /// type, but returns it as type `memref<? x $tp>` (rather than as type 233 /// `memref<$sz x $tp>`). 234 static Value genAlloca(ConversionPatternRewriter &rewriter, Location loc, 235 unsigned sz, Type tp) { 236 auto memTp = MemRefType::get({ShapedType::kDynamicSize}, tp); 237 Value a = constantIndex(rewriter, loc, sz); 238 return rewriter.create<memref::AllocaOp>(loc, memTp, ValueRange{a}); 239 } 240 241 /// Generates an uninitialized temporary buffer with room for one value 242 /// of the given type, and returns the `memref<$tp>`. 243 static Value genAllocaScalar(ConversionPatternRewriter &rewriter, Location loc, 244 Type tp) { 245 return rewriter.create<memref::AllocaOp>(loc, MemRefType::get({}, tp)); 246 } 247 248 /// Generates a temporary buffer of the given type and given contents. 249 static Value genBuffer(ConversionPatternRewriter &rewriter, Location loc, 250 ArrayRef<Value> values) { 251 unsigned sz = values.size(); 252 assert(sz >= 1); 253 Value buffer = genAlloca(rewriter, loc, sz, values[0].getType()); 254 for (unsigned i = 0; i < sz; i++) { 255 Value idx = constantIndex(rewriter, loc, i); 256 rewriter.create<memref::StoreOp>(loc, values[i], buffer, idx); 257 } 258 return buffer; 259 } 260 261 /// Populates parameters required to call the "swiss army knife" method of the 262 /// sparse runtime support library for materializing sparse tensors into the 263 /// computation. 264 static void newParams(ConversionPatternRewriter &rewriter, 265 SmallVector<Value, 8> ¶ms, Operation *op, 266 SparseTensorEncodingAttr &enc, Action action, 267 ValueRange szs, Value ptr = Value()) { 268 Location loc = op->getLoc(); 269 ArrayRef<SparseTensorEncodingAttr::DimLevelType> dlt = enc.getDimLevelType(); 270 unsigned sz = dlt.size(); 271 // Sparsity annotations. 272 SmallVector<Value, 4> attrs; 273 for (unsigned i = 0; i < sz; i++) 274 attrs.push_back(constantDimLevelTypeEncoding(rewriter, loc, dlt[i])); 275 params.push_back(genBuffer(rewriter, loc, attrs)); 276 // Dimension sizes array of the enveloping tensor. Useful for either 277 // verification of external data, or for construction of internal data. 278 SmallVector<Value, 4> sizes; 279 for (Value s : szs) 280 sizes.push_back(s); 281 params.push_back(genBuffer(rewriter, loc, sizes)); 282 // Dimension order permutation array. This is the "identity" permutation by 283 // default, or otherwise the "reverse" permutation of a given ordering, so 284 // that indices can be mapped quickly to the right position. 285 SmallVector<Value, 4> rev(sz); 286 if (AffineMap p = enc.getDimOrdering()) { 287 for (unsigned i = 0; i < sz; i++) 288 rev[p.getDimPosition(i)] = constantIndex(rewriter, loc, i); 289 } else { 290 for (unsigned i = 0; i < sz; i++) 291 rev[i] = constantIndex(rewriter, loc, i); 292 } 293 params.push_back(genBuffer(rewriter, loc, rev)); 294 // Secondary and primary types encoding. 295 ShapedType resType = op->getResult(0).getType().cast<ShapedType>(); 296 params.push_back(constantPointerTypeEncoding(rewriter, loc, enc)); 297 params.push_back(constantIndexTypeEncoding(rewriter, loc, enc)); 298 params.push_back( 299 constantPrimaryTypeEncoding(rewriter, loc, resType.getElementType())); 300 // User action and pointer. 301 Type pTp = LLVM::LLVMPointerType::get(rewriter.getI8Type()); 302 if (!ptr) 303 ptr = rewriter.create<LLVM::NullOp>(loc, pTp); 304 params.push_back(constantAction(rewriter, loc, action)); 305 params.push_back(ptr); 306 } 307 308 /// Generates the comparison `v != 0` where `v` is of numeric type `t`. 309 /// For floating types, we use the "unordered" comparator (i.e., returns 310 /// true if `v` is NaN). 311 static Value genIsNonzero(ConversionPatternRewriter &rewriter, Location loc, 312 Value v) { 313 Type t = v.getType(); 314 Value zero = constantZero(rewriter, loc, t); 315 if (t.isa<FloatType>()) 316 return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE, v, 317 zero); 318 if (t.isIntOrIndex()) 319 return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, v, 320 zero); 321 llvm_unreachable("Unknown element type"); 322 } 323 324 /// Generates the code to read the value from tensor[ivs], and conditionally 325 /// stores the indices ivs to the memory in ind. The generated code looks like 326 /// the following and the insertion point after this routine is inside the 327 /// if-then branch behind the assignment to ind. This is to ensure that the 328 /// addEltX call generated after is inside the if-then branch. 329 /// if (tensor[ivs]!=0) { 330 /// ind = ivs 331 static Value genIndexAndValueForDense(ConversionPatternRewriter &rewriter, 332 Location loc, Value tensor, Value ind, 333 ValueRange ivs) { 334 Value val = rewriter.create<tensor::ExtractOp>(loc, tensor, ivs); 335 Value cond = genIsNonzero(rewriter, loc, val); 336 scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ false); 337 rewriter.setInsertionPointToStart(&ifOp.thenRegion().front()); 338 unsigned i = 0; 339 for (auto iv : ivs) { 340 Value idx = constantIndex(rewriter, loc, i++); 341 rewriter.create<memref::StoreOp>(loc, iv, ind, idx); 342 } 343 return val; 344 } 345 346 /// Generates a call that adds one element to a coordinate scheme. 347 /// In particular, this generates code like the following: 348 /// val = a[i1,..,ik]; 349 /// if val != 0 350 /// t->add(val, [i1,..,ik], [p1,..,pk]); 351 static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op, 352 Type eltType, Value ptr, Value val, Value ind, 353 Value perm) { 354 Location loc = op->getLoc(); 355 StringRef name; 356 if (eltType.isF64()) 357 name = "addEltF64"; 358 else if (eltType.isF32()) 359 name = "addEltF32"; 360 else if (eltType.isInteger(64)) 361 name = "addEltI64"; 362 else if (eltType.isInteger(32)) 363 name = "addEltI32"; 364 else if (eltType.isInteger(16)) 365 name = "addEltI16"; 366 else if (eltType.isInteger(8)) 367 name = "addEltI8"; 368 else 369 llvm_unreachable("Unknown element type"); 370 SmallVector<Value, 8> params; 371 params.push_back(ptr); 372 params.push_back(val); 373 params.push_back(ind); 374 params.push_back(perm); 375 Type pTp = LLVM::LLVMPointerType::get(rewriter.getI8Type()); 376 auto fn = getFunc(op, name, pTp, params, /*emitCInterface=*/true); 377 rewriter.create<CallOp>(loc, pTp, fn, params); 378 } 379 380 /// Generates a call to `iter->getNext()`. If there is a next element, 381 /// then it is copied into the out-parameters `ind` and `elemPtr`, 382 /// and the return value is true. If there isn't a next element, then 383 /// the memory for `iter` is freed and the return value is false. 384 static Value genGetNextCall(ConversionPatternRewriter &rewriter, Operation *op, 385 Value iter, Value ind, Value elemPtr) { 386 Location loc = op->getLoc(); 387 Type elemTp = elemPtr.getType().cast<ShapedType>().getElementType(); 388 StringRef name; 389 if (elemTp.isF64()) 390 name = "getNextF64"; 391 else if (elemTp.isF32()) 392 name = "getNextF32"; 393 else if (elemTp.isInteger(64)) 394 name = "getNextI64"; 395 else if (elemTp.isInteger(32)) 396 name = "getNextI32"; 397 else if (elemTp.isInteger(16)) 398 name = "getNextI16"; 399 else if (elemTp.isInteger(8)) 400 name = "getNextI8"; 401 else 402 llvm_unreachable("Unknown element type"); 403 SmallVector<Value, 3> params; 404 params.push_back(iter); 405 params.push_back(ind); 406 params.push_back(elemPtr); 407 Type i1 = rewriter.getI1Type(); 408 auto fn = getFunc(op, name, i1, params, /*emitCInterface=*/true); 409 auto call = rewriter.create<CallOp>(loc, i1, fn, params); 410 return call.getResult(0); 411 } 412 413 /// If the tensor is a sparse constant, generates and returns the pair of 414 /// the constants for the indices and the values. 415 static Optional<std::pair<Value, Value>> 416 genSplitSparseConstant(ConversionPatternRewriter &rewriter, Location loc, 417 Value tensor) { 418 if (auto constOp = tensor.getDefiningOp<arith::ConstantOp>()) { 419 if (auto attr = constOp.getValue().dyn_cast<SparseElementsAttr>()) { 420 DenseElementsAttr indicesAttr = attr.getIndices(); 421 Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr); 422 DenseElementsAttr valuesAttr = attr.getValues(); 423 Value values = rewriter.create<arith::ConstantOp>(loc, valuesAttr); 424 return std::make_pair(indices, values); 425 } 426 } 427 return {}; 428 } 429 430 /// Generates the code to copy the index at indices[ivs] to ind, and return 431 /// the value at value[ivs]. 432 static Value genIndexAndValueForSparse(ConversionPatternRewriter &rewriter, 433 Location loc, Value indices, 434 Value values, Value ind, ValueRange ivs, 435 unsigned rank) { 436 for (unsigned i = 0; i < rank; i++) { 437 Value idx = constantIndex(rewriter, loc, i); 438 Value val = rewriter.create<tensor::ExtractOp>(loc, indices, 439 ValueRange{ivs[0], idx}); 440 val = 441 rewriter.create<arith::IndexCastOp>(loc, val, rewriter.getIndexType()); 442 rewriter.create<memref::StoreOp>(loc, val, ind, idx); 443 } 444 return rewriter.create<tensor::ExtractOp>(loc, values, ivs[0]); 445 } 446 447 /// Generates code to allocate a tensor of the given type, and zero 448 /// initialize it. If the tensor type has any dynamic sizes, then the 449 /// `sizes` parameter should be as filled by sizesFromPtr(); that way 450 /// we can reuse the genDimSizeCall() results generated by sizesFromPtr(). 451 static Value allocDenseTensor(ConversionPatternRewriter &rewriter, Location loc, 452 RankedTensorType tensorTp, ValueRange sizes) { 453 Type elemTp = tensorTp.getElementType(); 454 auto shape = tensorTp.getShape(); 455 auto memTp = MemRefType::get(shape, elemTp); 456 SmallVector<Value> dynamicSizes; 457 for (unsigned i = 0, rank = tensorTp.getRank(); i < rank; i++) { 458 if (shape[i] == ShapedType::kDynamicSize) 459 dynamicSizes.push_back(sizes[i]); 460 } 461 Value mem = rewriter.create<memref::AllocOp>(loc, memTp, dynamicSizes); 462 Value zero = constantZero(rewriter, loc, elemTp); 463 rewriter.create<linalg::FillOp>(loc, zero, mem).result(); 464 return mem; 465 } 466 467 /// Inserts the element returned by genGetNextCall(_, ind, elemPtr) into 468 /// the tensor created by allocDenseTensor(). The `rank` is the rank 469 /// of the `tensor` and the length of `ind`. 470 static void insertScalarIntoDenseTensor(ConversionPatternRewriter &rewriter, 471 Location loc, Value elemPtr, 472 Value tensor, unsigned rank, 473 Value ind) { 474 SmallVector<Value, 4> ivs; 475 ivs.reserve(rank); 476 for (unsigned i = 0; i < rank; i++) { 477 Value idx = constantIndex(rewriter, loc, i); 478 ivs.push_back(rewriter.create<memref::LoadOp>(loc, ind, idx)); 479 } 480 Value elemV = rewriter.create<memref::LoadOp>(loc, elemPtr); 481 rewriter.create<memref::StoreOp>(loc, elemV, tensor, ivs); 482 } 483 484 //===----------------------------------------------------------------------===// 485 // Conversion rules. 486 //===----------------------------------------------------------------------===// 487 488 /// Sparse conversion rule for returns. 489 class SparseReturnConverter : public OpConversionPattern<ReturnOp> { 490 public: 491 using OpConversionPattern::OpConversionPattern; 492 LogicalResult 493 matchAndRewrite(ReturnOp op, OpAdaptor adaptor, 494 ConversionPatternRewriter &rewriter) const override { 495 rewriter.replaceOpWithNewOp<ReturnOp>(op, adaptor.getOperands()); 496 return success(); 497 } 498 }; 499 500 /// Sparse conversion rule for dimension accesses. 501 class SparseTensorToDimSizeConverter 502 : public OpConversionPattern<tensor::DimOp> { 503 public: 504 using OpConversionPattern::OpConversionPattern; 505 LogicalResult 506 matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor, 507 ConversionPatternRewriter &rewriter) const override { 508 // Only rewrite annotated DimOp with constant index. 509 auto enc = getSparseTensorEncoding(op.source().getType()); 510 if (!enc) 511 return failure(); 512 Optional<int64_t> index = op.getConstantIndex(); 513 if (!index.hasValue()) 514 return failure(); 515 // Generate the call. 516 Value src = adaptor.getOperands()[0]; 517 int64_t idx = index.getValue(); 518 rewriter.replaceOp(op, genDimSizeCall(rewriter, op, enc, src, idx)); 519 return success(); 520 } 521 }; 522 523 /// Sparse conversion rule for trivial tensor casts. 524 class SparseCastConverter : public OpConversionPattern<tensor::CastOp> { 525 using OpConversionPattern::OpConversionPattern; 526 LogicalResult 527 matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor, 528 ConversionPatternRewriter &rewriter) const override { 529 // Only rewrite identically annotated source/dest. 530 auto encDst = getSparseTensorEncoding(op.getType()); 531 auto encSrc = getSparseTensorEncoding(op.source().getType()); 532 if (!encDst || encDst != encSrc) 533 return failure(); 534 rewriter.replaceOp(op, adaptor.getOperands()); 535 return success(); 536 } 537 }; 538 539 /// Sparse conversion rule for the new operator. 540 class SparseTensorNewConverter : public OpConversionPattern<NewOp> { 541 using OpConversionPattern::OpConversionPattern; 542 LogicalResult 543 matchAndRewrite(NewOp op, OpAdaptor adaptor, 544 ConversionPatternRewriter &rewriter) const override { 545 Type resType = op.getType(); 546 auto enc = getSparseTensorEncoding(resType); 547 if (!enc) 548 return failure(); 549 // Generate the call to construct tensor from ptr. The sizes are 550 // inferred from the result type of the new operator. 551 SmallVector<Value, 4> sizes; 552 SmallVector<Value, 8> params; 553 sizesFromType(rewriter, sizes, op.getLoc(), resType.cast<ShapedType>()); 554 Value ptr = adaptor.getOperands()[0]; 555 newParams(rewriter, params, op, enc, Action::kFromFile, sizes, ptr); 556 rewriter.replaceOp(op, genNewCall(rewriter, op, params)); 557 return success(); 558 } 559 }; 560 561 /// Sparse conversion rule for the init operator. 562 class SparseTensorInitConverter : public OpConversionPattern<InitOp> { 563 using OpConversionPattern::OpConversionPattern; 564 LogicalResult 565 matchAndRewrite(InitOp op, OpAdaptor adaptor, 566 ConversionPatternRewriter &rewriter) const override { 567 Type resType = op.getType(); 568 auto enc = getSparseTensorEncoding(resType); 569 if (!enc) 570 return failure(); 571 // Generate the call to construct empty tensor. The sizes are 572 // explicitly defined by the arguments to the init operator. 573 SmallVector<Value, 8> params; 574 newParams(rewriter, params, op, enc, Action::kEmpty, adaptor.getOperands()); 575 rewriter.replaceOp(op, genNewCall(rewriter, op, params)); 576 return success(); 577 } 578 }; 579 580 /// Sparse conversion rule for the convert operator. 581 class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> { 582 using OpConversionPattern::OpConversionPattern; 583 LogicalResult 584 matchAndRewrite(ConvertOp op, OpAdaptor adaptor, 585 ConversionPatternRewriter &rewriter) const override { 586 Location loc = op->getLoc(); 587 Type resType = op.getType(); 588 Type srcType = op.source().getType(); 589 auto encDst = getSparseTensorEncoding(resType); 590 auto encSrc = getSparseTensorEncoding(srcType); 591 Value src = adaptor.getOperands()[0]; 592 if (encDst && encSrc) { 593 // This is a sparse => sparse conversion, which is handled as follows: 594 // t = src->toCOO(); ; src to COO in dst order 595 // dst = newSparseTensor(t) 596 // Using the coordinate scheme as an intermediate does not always 597 // yield the fastest conversion but avoids the need for a full 598 // O(N^2) conversion matrix. 599 if (encDst == encSrc) { 600 rewriter.replaceOp(op, adaptor.getOperands()); // hidden nop cast 601 return success(); 602 } 603 SmallVector<Value, 4> sizes; 604 SmallVector<Value, 8> params; 605 sizesFromPtr(rewriter, sizes, op, encSrc, srcType.cast<ShapedType>(), 606 src); 607 // Set up encoding with right mix of src and dst so that the two 608 // method calls can share most parameters, while still providing 609 // the correct sparsity information to either of them. 610 auto enc = SparseTensorEncodingAttr::get( 611 op->getContext(), encDst.getDimLevelType(), encDst.getDimOrdering(), 612 encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth()); 613 newParams(rewriter, params, op, enc, Action::kToCOO, sizes, src); 614 Value coo = genNewCall(rewriter, op, params); 615 params[3] = constantPointerTypeEncoding(rewriter, loc, encDst); 616 params[4] = constantIndexTypeEncoding(rewriter, loc, encDst); 617 params[6] = constantAction(rewriter, loc, Action::kFromCOO); 618 params[7] = coo; 619 rewriter.replaceOp(op, genNewCall(rewriter, op, params)); 620 return success(); 621 } 622 if (!encDst && encSrc) { 623 // This is sparse => dense conversion, which is handled as follows: 624 // dst = new Tensor(0); 625 // iter = src->toCOO(); 626 // iter->startIterator(); 627 // while (elem = iter->getNext()) { 628 // dst[elem.indices] = elem.value; 629 // } 630 RankedTensorType dstTensorTp = resType.cast<RankedTensorType>(); 631 RankedTensorType srcTensorTp = srcType.cast<RankedTensorType>(); 632 unsigned rank = dstTensorTp.getRank(); 633 Type elemTp = dstTensorTp.getElementType(); 634 // Fabricate a no-permutation encoding for newParams(). 635 // The pointer/index types must be those of `src`. 636 // The dimLevelTypes aren't actually used by Action::kToIterator. 637 encDst = SparseTensorEncodingAttr::get( 638 op->getContext(), 639 SmallVector<SparseTensorEncodingAttr::DimLevelType>( 640 rank, SparseTensorEncodingAttr::DimLevelType::Dense), 641 AffineMap(), encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth()); 642 SmallVector<Value, 4> sizes; 643 SmallVector<Value, 8> params; 644 sizesFromPtr(rewriter, sizes, op, encSrc, srcTensorTp, src); 645 newParams(rewriter, params, op, encDst, Action::kToIterator, sizes, src); 646 Value iter = genNewCall(rewriter, op, params); 647 Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType()); 648 Value elemPtr = genAllocaScalar(rewriter, loc, elemTp); 649 Value dst = allocDenseTensor(rewriter, loc, dstTensorTp, sizes); 650 SmallVector<Value> noArgs; 651 SmallVector<Type> noTypes; 652 auto whileOp = rewriter.create<scf::WhileOp>(loc, noTypes, noArgs); 653 Block *before = rewriter.createBlock(&whileOp.before(), {}, noTypes); 654 rewriter.setInsertionPointToEnd(before); 655 Value cond = genGetNextCall(rewriter, op, iter, ind, elemPtr); 656 rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments()); 657 Block *after = rewriter.createBlock(&whileOp.after(), {}, noTypes); 658 rewriter.setInsertionPointToStart(after); 659 insertScalarIntoDenseTensor(rewriter, loc, elemPtr, dst, rank, ind); 660 rewriter.create<scf::YieldOp>(loc); 661 rewriter.setInsertionPointAfter(whileOp); 662 rewriter.replaceOpWithNewOp<memref::TensorLoadOp>(op, resType, dst); 663 return success(); 664 } 665 if (!encDst && !encSrc) { 666 // dense => dense 667 return failure(); 668 } 669 // This is a dense => sparse conversion or a sparse constant in COO => 670 // sparse conversion, which is handled as follows: 671 // t = newSparseCOO() 672 // ...code to fill the COO tensor t... 673 // s = newSparseTensor(t) 674 // 675 // To fill the COO tensor from a dense tensor: 676 // for i1 in dim1 677 // .. 678 // for ik in dimk 679 // val = a[i1,..,ik] 680 // if val != 0 681 // t->add(val, [i1,..,ik], [p1,..,pk]) 682 // 683 // To fill the COO tensor from a sparse constant in COO format: 684 // for i in range(NNZ) 685 // val = values[i] 686 // [i1,..,ik] = indices[i] 687 // t->add(val, [i1,..,ik], [p1,..,pk]) 688 // 689 // Note that the dense tensor traversal code is actually implemented 690 // using MLIR IR to avoid having to expose too much low-level 691 // memref traversal details to the runtime support library. 692 // Also note that the code below only generates the "new" ops and 693 // the loop-nest per se; whereas the entire body of the innermost 694 // loop is generated by genAddElt(). 695 ShapedType stp = resType.cast<ShapedType>(); 696 unsigned rank = stp.getRank(); 697 SmallVector<Value, 4> sizes; 698 SmallVector<Value, 8> params; 699 sizesFromSrc(rewriter, sizes, loc, src); 700 newParams(rewriter, params, op, encDst, Action::kEmptyCOO, sizes); 701 Value ptr = genNewCall(rewriter, op, params); 702 Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType()); 703 Value perm = params[2]; 704 SmallVector<Value> lo; 705 SmallVector<Value> hi; 706 SmallVector<Value> st; 707 Value zero = constantIndex(rewriter, loc, 0); 708 Value one = constantIndex(rewriter, loc, 1); 709 auto indicesValues = genSplitSparseConstant(rewriter, loc, src); 710 bool isCOOConstant = indicesValues.hasValue(); 711 Value indices; 712 Value values; 713 if (isCOOConstant) { 714 indices = indicesValues->first; 715 values = indicesValues->second; 716 lo.push_back(zero); 717 hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, values, 0)); 718 st.push_back(one); 719 } else { 720 for (unsigned i = 0; i < rank; i++) { 721 lo.push_back(zero); 722 hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, src, i)); 723 st.push_back(one); 724 } 725 } 726 Type eltType = stp.getElementType(); 727 scf::buildLoopNest( 728 rewriter, op.getLoc(), lo, hi, st, {}, 729 [&](OpBuilder &builder, Location loc, ValueRange ivs, 730 ValueRange args) -> scf::ValueVector { 731 Value val; 732 if (isCOOConstant) 733 val = genIndexAndValueForSparse(rewriter, loc, indices, values, ind, 734 ivs, rank); 735 else 736 val = genIndexAndValueForDense(rewriter, loc, src, ind, ivs); 737 genAddEltCall(rewriter, op, eltType, ptr, val, ind, perm); 738 return {}; 739 }); 740 // Final call to construct sparse tensor storage. 741 params[6] = constantAction(rewriter, loc, Action::kFromCOO); 742 params[7] = ptr; 743 rewriter.replaceOp(op, genNewCall(rewriter, op, params)); 744 return success(); 745 } 746 }; 747 748 /// Sparse conversion rule for the release operator. 749 class SparseTensorReleaseConverter : public OpConversionPattern<ReleaseOp> { 750 public: 751 using OpConversionPattern::OpConversionPattern; 752 LogicalResult 753 matchAndRewrite(ReleaseOp op, OpAdaptor adaptor, 754 ConversionPatternRewriter &rewriter) const override { 755 StringRef name = "delSparseTensor"; 756 TypeRange none; 757 auto fn = getFunc(op, name, none, adaptor.getOperands()); 758 rewriter.create<CallOp>(op.getLoc(), none, fn, adaptor.getOperands()); 759 rewriter.eraseOp(op); 760 return success(); 761 } 762 }; 763 764 /// Sparse conversion rule for pointer accesses. 765 class SparseTensorToPointersConverter 766 : public OpConversionPattern<ToPointersOp> { 767 public: 768 using OpConversionPattern::OpConversionPattern; 769 LogicalResult 770 matchAndRewrite(ToPointersOp op, OpAdaptor adaptor, 771 ConversionPatternRewriter &rewriter) const override { 772 Type resType = op.getType(); 773 Type eltType = resType.cast<ShapedType>().getElementType(); 774 StringRef name; 775 if (eltType.isIndex()) 776 name = "sparsePointers"; 777 else if (eltType.isInteger(64)) 778 name = "sparsePointers64"; 779 else if (eltType.isInteger(32)) 780 name = "sparsePointers32"; 781 else if (eltType.isInteger(16)) 782 name = "sparsePointers16"; 783 else if (eltType.isInteger(8)) 784 name = "sparsePointers8"; 785 else 786 return failure(); 787 auto fn = getFunc(op, name, resType, adaptor.getOperands(), 788 /*emitCInterface=*/true); 789 rewriter.replaceOpWithNewOp<CallOp>(op, resType, fn, adaptor.getOperands()); 790 return success(); 791 } 792 }; 793 794 /// Sparse conversion rule for index accesses. 795 class SparseTensorToIndicesConverter : public OpConversionPattern<ToIndicesOp> { 796 public: 797 using OpConversionPattern::OpConversionPattern; 798 LogicalResult 799 matchAndRewrite(ToIndicesOp op, OpAdaptor adaptor, 800 ConversionPatternRewriter &rewriter) const override { 801 Type resType = op.getType(); 802 Type eltType = resType.cast<ShapedType>().getElementType(); 803 StringRef name; 804 if (eltType.isIndex()) 805 name = "sparseIndices"; 806 else if (eltType.isInteger(64)) 807 name = "sparseIndices64"; 808 else if (eltType.isInteger(32)) 809 name = "sparseIndices32"; 810 else if (eltType.isInteger(16)) 811 name = "sparseIndices16"; 812 else if (eltType.isInteger(8)) 813 name = "sparseIndices8"; 814 else 815 return failure(); 816 auto fn = getFunc(op, name, resType, adaptor.getOperands(), 817 /*emitCInterface=*/true); 818 rewriter.replaceOpWithNewOp<CallOp>(op, resType, fn, adaptor.getOperands()); 819 return success(); 820 } 821 }; 822 823 /// Sparse conversion rule for value accesses. 824 class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> { 825 public: 826 using OpConversionPattern::OpConversionPattern; 827 LogicalResult 828 matchAndRewrite(ToValuesOp op, OpAdaptor adaptor, 829 ConversionPatternRewriter &rewriter) const override { 830 Type resType = op.getType(); 831 Type eltType = resType.cast<ShapedType>().getElementType(); 832 StringRef name; 833 if (eltType.isF64()) 834 name = "sparseValuesF64"; 835 else if (eltType.isF32()) 836 name = "sparseValuesF32"; 837 else if (eltType.isInteger(64)) 838 name = "sparseValuesI64"; 839 else if (eltType.isInteger(32)) 840 name = "sparseValuesI32"; 841 else if (eltType.isInteger(16)) 842 name = "sparseValuesI16"; 843 else if (eltType.isInteger(8)) 844 name = "sparseValuesI8"; 845 else 846 return failure(); 847 auto fn = getFunc(op, name, resType, adaptor.getOperands(), 848 /*emitCInterface=*/true); 849 rewriter.replaceOpWithNewOp<CallOp>(op, resType, fn, adaptor.getOperands()); 850 return success(); 851 } 852 }; 853 854 /// Sparse conversion rule for tensor rematerialization. 855 class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> { 856 public: 857 using OpConversionPattern::OpConversionPattern; 858 LogicalResult 859 matchAndRewrite(LoadOp op, OpAdaptor adaptor, 860 ConversionPatternRewriter &rewriter) const override { 861 if (op.hasInserts()) { 862 // Finalize any pending insertions. 863 StringRef name = "endInsert"; 864 TypeRange noTp; 865 auto fn = getFunc(op, name, noTp, adaptor.getOperands()); 866 rewriter.create<CallOp>(op.getLoc(), noTp, fn, adaptor.getOperands()); 867 } 868 rewriter.replaceOp(op, adaptor.getOperands()); 869 return success(); 870 } 871 }; 872 873 /// Sparse conversion rule for inserting in lexicographic index order. 874 class SparseTensorLexInsertConverter : public OpConversionPattern<LexInsertOp> { 875 public: 876 using OpConversionPattern::OpConversionPattern; 877 LogicalResult 878 matchAndRewrite(LexInsertOp op, OpAdaptor adaptor, 879 ConversionPatternRewriter &rewriter) const override { 880 Type srcType = op.tensor().getType(); 881 Type eltType = srcType.cast<ShapedType>().getElementType(); 882 StringRef name; 883 if (eltType.isF64()) 884 name = "lexInsertF64"; 885 else if (eltType.isF32()) 886 name = "lexInsertF32"; 887 else if (eltType.isInteger(64)) 888 name = "lexInsertI64"; 889 else if (eltType.isInteger(32)) 890 name = "lexInsertI32"; 891 else if (eltType.isInteger(16)) 892 name = "lexInsertI16"; 893 else if (eltType.isInteger(8)) 894 name = "lexInsertI8"; 895 else 896 llvm_unreachable("Unknown element type"); 897 TypeRange noTp; 898 auto fn = 899 getFunc(op, name, noTp, adaptor.getOperands(), /*emitCInterface=*/true); 900 rewriter.replaceOpWithNewOp<CallOp>(op, noTp, fn, adaptor.getOperands()); 901 return success(); 902 } 903 }; 904 905 } // namespace 906 907 //===----------------------------------------------------------------------===// 908 // Public method for populating conversion rules. 909 //===----------------------------------------------------------------------===// 910 911 /// Populates the given patterns list with conversion rules required for 912 /// the sparsification of linear algebra operations. 913 void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter, 914 RewritePatternSet &patterns) { 915 patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter, 916 SparseCastConverter, SparseTensorNewConverter, 917 SparseTensorInitConverter, SparseTensorConvertConverter, 918 SparseTensorReleaseConverter, SparseTensorToPointersConverter, 919 SparseTensorToIndicesConverter, SparseTensorToValuesConverter, 920 SparseTensorLoadConverter, SparseTensorLexInsertConverter>( 921 typeConverter, patterns.getContext()); 922 } 923