1 //===- SparseTensorConversion.cpp - Sparse tensor primitives conversion ---===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Convert sparse tensor primitives to calls into a runtime support library. 10 // Note that this is a current implementation choice to keep the conversion 11 // simple. In principle, these primitives could also be converted to actual 12 // elaborate IR code that implements the primitives on the selected sparse 13 // tensor storage schemes. 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 18 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 19 #include "mlir/Dialect/Linalg/Utils/Utils.h" 20 #include "mlir/Dialect/MemRef/IR/MemRef.h" 21 #include "mlir/Dialect/SCF/SCF.h" 22 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 23 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" 24 #include "mlir/Dialect/StandardOps/IR/Ops.h" 25 #include "mlir/Dialect/Tensor/IR/Tensor.h" 26 #include "mlir/ExecutionEngine/SparseTensorUtils.h" 27 #include "mlir/Transforms/DialectConversion.h" 28 29 using namespace mlir; 30 using namespace mlir::sparse_tensor; 31 32 namespace { 33 34 //===----------------------------------------------------------------------===// 35 // Helper methods. 36 //===----------------------------------------------------------------------===// 37 38 /// Generates a constant zero of the given type. 39 inline static Value constantZero(ConversionPatternRewriter &rewriter, 40 Location loc, Type t) { 41 return rewriter.create<arith::ConstantOp>(loc, t, rewriter.getZeroAttr(t)); 42 } 43 44 /// Generates a constant of `index` type. 45 inline static Value constantIndex(ConversionPatternRewriter &rewriter, 46 Location loc, int64_t i) { 47 return rewriter.create<arith::ConstantIndexOp>(loc, i); 48 } 49 50 /// Generates a constant of `i32` type. 51 inline static Value constantI32(ConversionPatternRewriter &rewriter, 52 Location loc, int32_t i) { 53 return rewriter.create<arith::ConstantIntOp>(loc, i, 32); 54 } 55 56 /// Generates a constant of `i8` type. 57 inline static Value constantI8(ConversionPatternRewriter &rewriter, 58 Location loc, int8_t i) { 59 return rewriter.create<arith::ConstantIntOp>(loc, i, 8); 60 } 61 62 /// Generates a constant of the given `Action`. 63 static Value constantAction(ConversionPatternRewriter &rewriter, Location loc, 64 Action action) { 65 return constantI32(rewriter, loc, static_cast<uint32_t>(action)); 66 } 67 68 /// Generates a constant of the internal type encoding for overhead storage. 69 static Value constantOverheadTypeEncoding(ConversionPatternRewriter &rewriter, 70 Location loc, unsigned width) { 71 OverheadType sec; 72 switch (width) { 73 default: 74 sec = OverheadType::kU64; 75 break; 76 case 32: 77 sec = OverheadType::kU32; 78 break; 79 case 16: 80 sec = OverheadType::kU16; 81 break; 82 case 8: 83 sec = OverheadType::kU8; 84 break; 85 } 86 return constantI32(rewriter, loc, static_cast<uint32_t>(sec)); 87 } 88 89 /// Generates a constant of the internal type encoding for pointer 90 /// overhead storage. 91 static Value constantPointerTypeEncoding(ConversionPatternRewriter &rewriter, 92 Location loc, 93 SparseTensorEncodingAttr &enc) { 94 return constantOverheadTypeEncoding(rewriter, loc, enc.getPointerBitWidth()); 95 } 96 97 /// Generates a constant of the internal type encoding for index overhead 98 /// storage. 99 static Value constantIndexTypeEncoding(ConversionPatternRewriter &rewriter, 100 Location loc, 101 SparseTensorEncodingAttr &enc) { 102 return constantOverheadTypeEncoding(rewriter, loc, enc.getIndexBitWidth()); 103 } 104 105 /// Generates a constant of the internal type encoding for primary storage. 106 static Value constantPrimaryTypeEncoding(ConversionPatternRewriter &rewriter, 107 Location loc, Type tp) { 108 PrimaryType primary; 109 if (tp.isF64()) 110 primary = PrimaryType::kF64; 111 else if (tp.isF32()) 112 primary = PrimaryType::kF32; 113 else if (tp.isInteger(64)) 114 primary = PrimaryType::kI64; 115 else if (tp.isInteger(32)) 116 primary = PrimaryType::kI32; 117 else if (tp.isInteger(16)) 118 primary = PrimaryType::kI16; 119 else if (tp.isInteger(8)) 120 primary = PrimaryType::kI8; 121 else 122 llvm_unreachable("Unknown element type"); 123 return constantI32(rewriter, loc, static_cast<uint32_t>(primary)); 124 } 125 126 /// Generates a constant of the internal dimension level type encoding. 127 static Value 128 constantDimLevelTypeEncoding(ConversionPatternRewriter &rewriter, Location loc, 129 SparseTensorEncodingAttr::DimLevelType dlt) { 130 DimLevelType dlt2; 131 switch (dlt) { 132 case SparseTensorEncodingAttr::DimLevelType::Dense: 133 dlt2 = DimLevelType::kDense; 134 break; 135 case SparseTensorEncodingAttr::DimLevelType::Compressed: 136 dlt2 = DimLevelType::kCompressed; 137 break; 138 case SparseTensorEncodingAttr::DimLevelType::Singleton: 139 dlt2 = DimLevelType::kSingleton; 140 break; 141 } 142 return constantI8(rewriter, loc, static_cast<uint8_t>(dlt2)); 143 } 144 145 /// Returns a function reference (first hit also inserts into module). Sets 146 /// the "_emit_c_interface" on the function declaration when requested, 147 /// so that LLVM lowering generates a wrapper function that takes care 148 /// of ABI complications with passing in and returning MemRefs to C functions. 149 static FlatSymbolRefAttr getFunc(Operation *op, StringRef name, 150 TypeRange resultType, ValueRange operands, 151 bool emitCInterface = false) { 152 MLIRContext *context = op->getContext(); 153 auto module = op->getParentOfType<ModuleOp>(); 154 auto result = SymbolRefAttr::get(context, name); 155 auto func = module.lookupSymbol<FuncOp>(result.getAttr()); 156 if (!func) { 157 OpBuilder moduleBuilder(module.getBodyRegion()); 158 func = moduleBuilder.create<FuncOp>( 159 op->getLoc(), name, 160 FunctionType::get(context, operands.getTypes(), resultType)); 161 func.setPrivate(); 162 if (emitCInterface) 163 func->setAttr("llvm.emit_c_interface", UnitAttr::get(context)); 164 } 165 return result; 166 } 167 168 /// Generates dimension size call. 169 static Value genDimSizeCall(ConversionPatternRewriter &rewriter, Operation *op, 170 SparseTensorEncodingAttr &enc, Value src, 171 int64_t idx) { 172 // Permute the index according to an optional dimension ordering. 173 if (AffineMap p = enc.getDimOrdering()) 174 idx = p.getPermutedPosition(idx); 175 // Generate the call. 176 Location loc = op->getLoc(); 177 StringRef name = "sparseDimSize"; 178 SmallVector<Value, 2> params; 179 params.push_back(src); 180 params.push_back(constantIndex(rewriter, loc, idx)); 181 Type iTp = rewriter.getIndexType(); 182 auto fn = getFunc(op, name, iTp, params); 183 return rewriter.create<CallOp>(loc, iTp, fn, params).getResult(0); 184 } 185 186 /// Generates a call into the "swiss army knife" method of the sparse runtime 187 /// support library for materializing sparse tensors into the computation. 188 static Value genNewCall(ConversionPatternRewriter &rewriter, Operation *op, 189 ArrayRef<Value> params) { 190 Location loc = op->getLoc(); 191 StringRef name = "newSparseTensor"; 192 Type pTp = LLVM::LLVMPointerType::get(rewriter.getI8Type()); 193 auto fn = getFunc(op, name, pTp, params, /*emitCInterface=*/true); 194 auto call = rewriter.create<CallOp>(loc, pTp, fn, params); 195 return call.getResult(0); 196 } 197 198 /// Populates given sizes array from type. 199 static void sizesFromType(ConversionPatternRewriter &rewriter, 200 SmallVector<Value, 4> &sizes, Location loc, 201 ShapedType stp) { 202 auto shape = stp.getShape(); 203 for (unsigned i = 0, rank = stp.getRank(); i < rank; i++) { 204 uint64_t s = shape[i] == ShapedType::kDynamicSize ? 0 : shape[i]; 205 sizes.push_back(constantIndex(rewriter, loc, s)); 206 } 207 } 208 209 /// Populates given sizes array from source. 210 static void sizesFromSrc(ConversionPatternRewriter &rewriter, 211 SmallVector<Value, 4> &sizes, Location loc, 212 Value src) { 213 ShapedType stp = src.getType().cast<ShapedType>(); 214 for (unsigned i = 0, rank = stp.getRank(); i < rank; i++) 215 sizes.push_back(linalg::createOrFoldDimOp(rewriter, loc, src, i)); 216 } 217 218 /// Populates given sizes array from type (for static sizes) and from 219 /// an already converted into opague pointer source (for dynamic sizes). 220 static void sizesFromPtr(ConversionPatternRewriter &rewriter, 221 SmallVector<Value, 4> &sizes, Operation *op, 222 SparseTensorEncodingAttr &enc, ShapedType stp, 223 Value src) { 224 auto shape = stp.getShape(); 225 for (unsigned i = 0, rank = stp.getRank(); i < rank; i++) 226 if (shape[i] == ShapedType::kDynamicSize) 227 sizes.push_back(genDimSizeCall(rewriter, op, enc, src, i)); 228 else 229 sizes.push_back(constantIndex(rewriter, op->getLoc(), shape[i])); 230 } 231 232 /// Generates an uninitialized temporary buffer of the given size and 233 /// type, but returns it as type `memref<? x $tp>` (rather than as type 234 /// `memref<$sz x $tp>`). 235 static Value genAlloca(ConversionPatternRewriter &rewriter, Location loc, 236 unsigned sz, Type tp) { 237 auto memTp = MemRefType::get({ShapedType::kDynamicSize}, tp); 238 Value a = constantIndex(rewriter, loc, sz); 239 return rewriter.create<memref::AllocaOp>(loc, memTp, ValueRange{a}); 240 } 241 242 /// Generates an uninitialized temporary buffer with room for one value 243 /// of the given type, and returns the `memref<$tp>`. 244 static Value genAllocaScalar(ConversionPatternRewriter &rewriter, Location loc, 245 Type tp) { 246 return rewriter.create<memref::AllocaOp>(loc, MemRefType::get({}, tp)); 247 } 248 249 /// Generates a temporary buffer of the given type and given contents. 250 static Value genBuffer(ConversionPatternRewriter &rewriter, Location loc, 251 ArrayRef<Value> values) { 252 unsigned sz = values.size(); 253 assert(sz >= 1); 254 Value buffer = genAlloca(rewriter, loc, sz, values[0].getType()); 255 for (unsigned i = 0; i < sz; i++) { 256 Value idx = constantIndex(rewriter, loc, i); 257 rewriter.create<memref::StoreOp>(loc, values[i], buffer, idx); 258 } 259 return buffer; 260 } 261 262 /// Populates parameters required to call the "swiss army knife" method of the 263 /// sparse runtime support library for materializing sparse tensors into the 264 /// computation. 265 static void newParams(ConversionPatternRewriter &rewriter, 266 SmallVector<Value, 8> ¶ms, Operation *op, 267 SparseTensorEncodingAttr &enc, Action action, 268 ValueRange szs, Value ptr = Value()) { 269 Location loc = op->getLoc(); 270 ArrayRef<SparseTensorEncodingAttr::DimLevelType> dlt = enc.getDimLevelType(); 271 unsigned sz = dlt.size(); 272 // Sparsity annotations. 273 SmallVector<Value, 4> attrs; 274 for (unsigned i = 0; i < sz; i++) 275 attrs.push_back(constantDimLevelTypeEncoding(rewriter, loc, dlt[i])); 276 params.push_back(genBuffer(rewriter, loc, attrs)); 277 // Dimension sizes array of the enveloping tensor. Useful for either 278 // verification of external data, or for construction of internal data. 279 SmallVector<Value, 4> sizes; 280 for (Value s : szs) 281 sizes.push_back(s); 282 params.push_back(genBuffer(rewriter, loc, sizes)); 283 // Dimension order permutation array. This is the "identity" permutation by 284 // default, or otherwise the "reverse" permutation of a given ordering, so 285 // that indices can be mapped quickly to the right position. 286 SmallVector<Value, 4> rev(sz); 287 if (AffineMap p = enc.getDimOrdering()) { 288 for (unsigned i = 0; i < sz; i++) 289 rev[p.getDimPosition(i)] = constantIndex(rewriter, loc, i); 290 } else { 291 for (unsigned i = 0; i < sz; i++) 292 rev[i] = constantIndex(rewriter, loc, i); 293 } 294 params.push_back(genBuffer(rewriter, loc, rev)); 295 // Secondary and primary types encoding. 296 ShapedType resType = op->getResult(0).getType().cast<ShapedType>(); 297 params.push_back(constantPointerTypeEncoding(rewriter, loc, enc)); 298 params.push_back(constantIndexTypeEncoding(rewriter, loc, enc)); 299 params.push_back( 300 constantPrimaryTypeEncoding(rewriter, loc, resType.getElementType())); 301 // User action and pointer. 302 Type pTp = LLVM::LLVMPointerType::get(rewriter.getI8Type()); 303 if (!ptr) 304 ptr = rewriter.create<LLVM::NullOp>(loc, pTp); 305 params.push_back(constantAction(rewriter, loc, action)); 306 params.push_back(ptr); 307 } 308 309 /// Generates the comparison `v != 0` where `v` is of numeric type `t`. 310 /// For floating types, we use the "unordered" comparator (i.e., returns 311 /// true if `v` is NaN). 312 static Value genIsNonzero(ConversionPatternRewriter &rewriter, Location loc, 313 Value v) { 314 Type t = v.getType(); 315 Value zero = constantZero(rewriter, loc, t); 316 if (t.isa<FloatType>()) 317 return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE, v, 318 zero); 319 if (t.isIntOrIndex()) 320 return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, v, 321 zero); 322 llvm_unreachable("Unknown element type"); 323 } 324 325 /// Generates the code to read the value from tensor[ivs], and conditionally 326 /// stores the indices ivs to the memory in ind. The generated code looks like 327 /// the following and the insertion point after this routine is inside the 328 /// if-then branch behind the assignment to ind. This is to ensure that the 329 /// addEltX call generated after is inside the if-then branch. 330 /// if (tensor[ivs]!=0) { 331 /// ind = ivs 332 static Value genIndexAndValueForDense(ConversionPatternRewriter &rewriter, 333 Location loc, Value tensor, Value ind, 334 ValueRange ivs) { 335 Value val = rewriter.create<tensor::ExtractOp>(loc, tensor, ivs); 336 Value cond = genIsNonzero(rewriter, loc, val); 337 scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ false); 338 rewriter.setInsertionPointToStart(&ifOp.thenRegion().front()); 339 unsigned i = 0; 340 for (auto iv : ivs) { 341 Value idx = constantIndex(rewriter, loc, i++); 342 rewriter.create<memref::StoreOp>(loc, iv, ind, idx); 343 } 344 return val; 345 } 346 347 /// Generates a call that adds one element to a coordinate scheme. 348 /// In particular, this generates code like the following: 349 /// val = a[i1,..,ik]; 350 /// if val != 0 351 /// t->add(val, [i1,..,ik], [p1,..,pk]); 352 static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op, 353 Type eltType, Value ptr, Value val, Value ind, 354 Value perm) { 355 Location loc = op->getLoc(); 356 StringRef name; 357 if (eltType.isF64()) 358 name = "addEltF64"; 359 else if (eltType.isF32()) 360 name = "addEltF32"; 361 else if (eltType.isInteger(64)) 362 name = "addEltI64"; 363 else if (eltType.isInteger(32)) 364 name = "addEltI32"; 365 else if (eltType.isInteger(16)) 366 name = "addEltI16"; 367 else if (eltType.isInteger(8)) 368 name = "addEltI8"; 369 else 370 llvm_unreachable("Unknown element type"); 371 SmallVector<Value, 8> params; 372 params.push_back(ptr); 373 params.push_back(val); 374 params.push_back(ind); 375 params.push_back(perm); 376 Type pTp = LLVM::LLVMPointerType::get(rewriter.getI8Type()); 377 auto fn = getFunc(op, name, pTp, params, /*emitCInterface=*/true); 378 rewriter.create<CallOp>(loc, pTp, fn, params); 379 } 380 381 /// Generates a call to `iter->getNext()`. If there is a next element, 382 /// then it is copied into the out-parameters `ind` and `elemPtr`, 383 /// and the return value is true. If there isn't a next element, then 384 /// the memory for `iter` is freed and the return value is false. 385 static Value genGetNextCall(ConversionPatternRewriter &rewriter, Operation *op, 386 Value iter, Value ind, Value elemPtr) { 387 Location loc = op->getLoc(); 388 Type elemTp = elemPtr.getType().cast<ShapedType>().getElementType(); 389 StringRef name; 390 if (elemTp.isF64()) 391 name = "getNextF64"; 392 else if (elemTp.isF32()) 393 name = "getNextF32"; 394 else if (elemTp.isInteger(64)) 395 name = "getNextI64"; 396 else if (elemTp.isInteger(32)) 397 name = "getNextI32"; 398 else if (elemTp.isInteger(16)) 399 name = "getNextI16"; 400 else if (elemTp.isInteger(8)) 401 name = "getNextI8"; 402 else 403 llvm_unreachable("Unknown element type"); 404 SmallVector<Value, 3> params; 405 params.push_back(iter); 406 params.push_back(ind); 407 params.push_back(elemPtr); 408 Type i1 = rewriter.getI1Type(); 409 auto fn = getFunc(op, name, i1, params, /*emitCInterface=*/true); 410 auto call = rewriter.create<CallOp>(loc, i1, fn, params); 411 return call.getResult(0); 412 } 413 414 /// If the tensor is a sparse constant, generates and returns the pair of 415 /// the constants for the indices and the values. 416 static Optional<std::pair<Value, Value>> 417 genSplitSparseConstant(ConversionPatternRewriter &rewriter, Location loc, 418 Value tensor) { 419 if (auto constOp = tensor.getDefiningOp<arith::ConstantOp>()) { 420 if (auto attr = constOp.getValue().dyn_cast<SparseElementsAttr>()) { 421 DenseElementsAttr indicesAttr = attr.getIndices(); 422 Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr); 423 DenseElementsAttr valuesAttr = attr.getValues(); 424 Value values = rewriter.create<arith::ConstantOp>(loc, valuesAttr); 425 return std::make_pair(indices, values); 426 } 427 } 428 return {}; 429 } 430 431 /// Generates the code to copy the index at indices[ivs] to ind, and return 432 /// the value at value[ivs]. 433 static Value genIndexAndValueForSparse(ConversionPatternRewriter &rewriter, 434 Location loc, Value indices, 435 Value values, Value ind, ValueRange ivs, 436 unsigned rank) { 437 for (unsigned i = 0; i < rank; i++) { 438 Value idx = constantIndex(rewriter, loc, i); 439 Value val = rewriter.create<tensor::ExtractOp>(loc, indices, 440 ValueRange{ivs[0], idx}); 441 val = 442 rewriter.create<arith::IndexCastOp>(loc, val, rewriter.getIndexType()); 443 rewriter.create<memref::StoreOp>(loc, val, ind, idx); 444 } 445 return rewriter.create<tensor::ExtractOp>(loc, values, ivs[0]); 446 } 447 448 /// Generates code to allocate a tensor of the given type, and zero 449 /// initialize it. If the tensor type has any dynamic sizes, then the 450 /// `sizes` parameter should be as filled by sizesFromPtr(); that way 451 /// we can reuse the genDimSizeCall() results generated by sizesFromPtr(). 452 static Value allocDenseTensor(ConversionPatternRewriter &rewriter, Location loc, 453 RankedTensorType tensorTp, ValueRange sizes) { 454 Type elemTp = tensorTp.getElementType(); 455 auto shape = tensorTp.getShape(); 456 auto memTp = MemRefType::get(shape, elemTp); 457 SmallVector<Value> dynamicSizes; 458 for (unsigned i = 0, rank = tensorTp.getRank(); i < rank; i++) { 459 if (shape[i] == ShapedType::kDynamicSize) 460 dynamicSizes.push_back(sizes[i]); 461 } 462 Value mem = rewriter.create<memref::AllocOp>(loc, memTp, dynamicSizes); 463 Value zero = constantZero(rewriter, loc, elemTp); 464 rewriter.create<linalg::FillOp>(loc, zero, mem).result(); 465 return mem; 466 } 467 468 /// Inserts the element returned by genGetNextCall(_, ind, elemPtr) into 469 /// the tensor created by allocDenseTensor(). The `rank` is the rank 470 /// of the `tensor` and the length of `ind`. 471 static void insertScalarIntoDenseTensor(ConversionPatternRewriter &rewriter, 472 Location loc, Value elemPtr, 473 Value tensor, unsigned rank, 474 Value ind) { 475 SmallVector<Value, 4> ivs; 476 ivs.reserve(rank); 477 for (unsigned i = 0; i < rank; i++) { 478 Value idx = constantIndex(rewriter, loc, i); 479 ivs.push_back(rewriter.create<memref::LoadOp>(loc, ind, idx)); 480 } 481 Value elemV = rewriter.create<memref::LoadOp>(loc, elemPtr); 482 rewriter.create<memref::StoreOp>(loc, elemV, tensor, ivs); 483 } 484 485 //===----------------------------------------------------------------------===// 486 // Conversion rules. 487 //===----------------------------------------------------------------------===// 488 489 /// Sparse conversion rule for returns. 490 class SparseReturnConverter : public OpConversionPattern<ReturnOp> { 491 public: 492 using OpConversionPattern::OpConversionPattern; 493 LogicalResult 494 matchAndRewrite(ReturnOp op, OpAdaptor adaptor, 495 ConversionPatternRewriter &rewriter) const override { 496 rewriter.replaceOpWithNewOp<ReturnOp>(op, adaptor.getOperands()); 497 return success(); 498 } 499 }; 500 501 /// Sparse conversion rule for dimension accesses. 502 class SparseTensorToDimSizeConverter 503 : public OpConversionPattern<tensor::DimOp> { 504 public: 505 using OpConversionPattern::OpConversionPattern; 506 LogicalResult 507 matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor, 508 ConversionPatternRewriter &rewriter) const override { 509 // Only rewrite annotated DimOp with constant index. 510 auto enc = getSparseTensorEncoding(op.source().getType()); 511 if (!enc) 512 return failure(); 513 Optional<int64_t> index = op.getConstantIndex(); 514 if (!index.hasValue()) 515 return failure(); 516 // Generate the call. 517 Value src = adaptor.getOperands()[0]; 518 int64_t idx = index.getValue(); 519 rewriter.replaceOp(op, genDimSizeCall(rewriter, op, enc, src, idx)); 520 return success(); 521 } 522 }; 523 524 /// Sparse conversion rule for trivial tensor casts. 525 class SparseCastConverter : public OpConversionPattern<tensor::CastOp> { 526 using OpConversionPattern::OpConversionPattern; 527 LogicalResult 528 matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor, 529 ConversionPatternRewriter &rewriter) const override { 530 // Only rewrite identically annotated source/dest. 531 auto encDst = getSparseTensorEncoding(op.getType()); 532 auto encSrc = getSparseTensorEncoding(op.source().getType()); 533 if (!encDst || encDst != encSrc) 534 return failure(); 535 rewriter.replaceOp(op, adaptor.getOperands()); 536 return success(); 537 } 538 }; 539 540 /// Sparse conversion rule for the new operator. 541 class SparseTensorNewConverter : public OpConversionPattern<NewOp> { 542 using OpConversionPattern::OpConversionPattern; 543 LogicalResult 544 matchAndRewrite(NewOp op, OpAdaptor adaptor, 545 ConversionPatternRewriter &rewriter) const override { 546 Type resType = op.getType(); 547 auto enc = getSparseTensorEncoding(resType); 548 if (!enc) 549 return failure(); 550 // Generate the call to construct tensor from ptr. The sizes are 551 // inferred from the result type of the new operator. 552 SmallVector<Value, 4> sizes; 553 SmallVector<Value, 8> params; 554 sizesFromType(rewriter, sizes, op.getLoc(), resType.cast<ShapedType>()); 555 Value ptr = adaptor.getOperands()[0]; 556 newParams(rewriter, params, op, enc, Action::kFromFile, sizes, ptr); 557 rewriter.replaceOp(op, genNewCall(rewriter, op, params)); 558 return success(); 559 } 560 }; 561 562 /// Sparse conversion rule for the init operator. 563 class SparseTensorInitConverter : public OpConversionPattern<InitOp> { 564 using OpConversionPattern::OpConversionPattern; 565 LogicalResult 566 matchAndRewrite(InitOp op, OpAdaptor adaptor, 567 ConversionPatternRewriter &rewriter) const override { 568 Type resType = op.getType(); 569 auto enc = getSparseTensorEncoding(resType); 570 if (!enc) 571 return failure(); 572 // Generate the call to construct empty tensor. The sizes are 573 // explicitly defined by the arguments to the init operator. 574 SmallVector<Value, 8> params; 575 newParams(rewriter, params, op, enc, Action::kEmpty, adaptor.getOperands()); 576 rewriter.replaceOp(op, genNewCall(rewriter, op, params)); 577 return success(); 578 } 579 }; 580 581 /// Sparse conversion rule for the convert operator. 582 class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> { 583 using OpConversionPattern::OpConversionPattern; 584 LogicalResult 585 matchAndRewrite(ConvertOp op, OpAdaptor adaptor, 586 ConversionPatternRewriter &rewriter) const override { 587 Location loc = op->getLoc(); 588 Type resType = op.getType(); 589 Type srcType = op.source().getType(); 590 auto encDst = getSparseTensorEncoding(resType); 591 auto encSrc = getSparseTensorEncoding(srcType); 592 Value src = adaptor.getOperands()[0]; 593 if (encDst && encSrc) { 594 // This is a sparse => sparse conversion, which is handled as follows: 595 // t = src->toCOO(); ; src to COO in dst order 596 // dst = newSparseTensor(t) 597 // Using the coordinate scheme as an intermediate does not always 598 // yield the fastest conversion but avoids the need for a full 599 // O(N^2) conversion matrix. 600 if (encDst == encSrc) { 601 rewriter.replaceOp(op, adaptor.getOperands()); // hidden nop cast 602 return success(); 603 } 604 SmallVector<Value, 4> sizes; 605 SmallVector<Value, 8> params; 606 sizesFromPtr(rewriter, sizes, op, encSrc, srcType.cast<ShapedType>(), 607 src); 608 // Set up encoding with right mix of src and dst so that the two 609 // method calls can share most parameters, while still providing 610 // the correct sparsity information to either of them. 611 auto enc = SparseTensorEncodingAttr::get( 612 op->getContext(), encDst.getDimLevelType(), encDst.getDimOrdering(), 613 encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth()); 614 newParams(rewriter, params, op, enc, Action::kToCOO, sizes, src); 615 Value coo = genNewCall(rewriter, op, params); 616 params[3] = constantPointerTypeEncoding(rewriter, loc, encDst); 617 params[4] = constantIndexTypeEncoding(rewriter, loc, encDst); 618 params[6] = constantAction(rewriter, loc, Action::kFromCOO); 619 params[7] = coo; 620 rewriter.replaceOp(op, genNewCall(rewriter, op, params)); 621 return success(); 622 } 623 if (!encDst && encSrc) { 624 // This is sparse => dense conversion, which is handled as follows: 625 // dst = new Tensor(0); 626 // iter = src->toCOO(); 627 // iter->startIterator(); 628 // while (elem = iter->getNext()) { 629 // dst[elem.indices] = elem.value; 630 // } 631 RankedTensorType dstTensorTp = resType.cast<RankedTensorType>(); 632 RankedTensorType srcTensorTp = srcType.cast<RankedTensorType>(); 633 unsigned rank = dstTensorTp.getRank(); 634 Type elemTp = dstTensorTp.getElementType(); 635 // Fabricate a no-permutation encoding for newParams(). 636 // The pointer/index types must be those of `src`. 637 // The dimLevelTypes aren't actually used by Action::kToIterator. 638 encDst = SparseTensorEncodingAttr::get( 639 op->getContext(), 640 SmallVector<SparseTensorEncodingAttr::DimLevelType>( 641 rank, SparseTensorEncodingAttr::DimLevelType::Dense), 642 AffineMap(), encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth()); 643 SmallVector<Value, 4> sizes; 644 SmallVector<Value, 8> params; 645 sizesFromPtr(rewriter, sizes, op, encSrc, srcTensorTp, src); 646 newParams(rewriter, params, op, encDst, Action::kToIterator, sizes, src); 647 Value iter = genNewCall(rewriter, op, params); 648 Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType()); 649 Value elemPtr = genAllocaScalar(rewriter, loc, elemTp); 650 Value dst = allocDenseTensor(rewriter, loc, dstTensorTp, sizes); 651 SmallVector<Value> noArgs; 652 SmallVector<Type> noTypes; 653 auto whileOp = rewriter.create<scf::WhileOp>(loc, noTypes, noArgs); 654 Block *before = rewriter.createBlock(&whileOp.before(), {}, noTypes); 655 rewriter.setInsertionPointToEnd(before); 656 Value cond = genGetNextCall(rewriter, op, iter, ind, elemPtr); 657 rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments()); 658 Block *after = rewriter.createBlock(&whileOp.after(), {}, noTypes); 659 rewriter.setInsertionPointToStart(after); 660 insertScalarIntoDenseTensor(rewriter, loc, elemPtr, dst, rank, ind); 661 rewriter.create<scf::YieldOp>(loc); 662 rewriter.setInsertionPointAfter(whileOp); 663 rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, resType, dst); 664 return success(); 665 } 666 if (!encDst && !encSrc) { 667 // dense => dense 668 return failure(); 669 } 670 // This is a dense => sparse conversion or a sparse constant in COO => 671 // sparse conversion, which is handled as follows: 672 // t = newSparseCOO() 673 // ...code to fill the COO tensor t... 674 // s = newSparseTensor(t) 675 // 676 // To fill the COO tensor from a dense tensor: 677 // for i1 in dim1 678 // .. 679 // for ik in dimk 680 // val = a[i1,..,ik] 681 // if val != 0 682 // t->add(val, [i1,..,ik], [p1,..,pk]) 683 // 684 // To fill the COO tensor from a sparse constant in COO format: 685 // for i in range(NNZ) 686 // val = values[i] 687 // [i1,..,ik] = indices[i] 688 // t->add(val, [i1,..,ik], [p1,..,pk]) 689 // 690 // Note that the dense tensor traversal code is actually implemented 691 // using MLIR IR to avoid having to expose too much low-level 692 // memref traversal details to the runtime support library. 693 // Also note that the code below only generates the "new" ops and 694 // the loop-nest per se; whereas the entire body of the innermost 695 // loop is generated by genAddElt(). 696 ShapedType stp = resType.cast<ShapedType>(); 697 unsigned rank = stp.getRank(); 698 SmallVector<Value, 4> sizes; 699 SmallVector<Value, 8> params; 700 sizesFromSrc(rewriter, sizes, loc, src); 701 newParams(rewriter, params, op, encDst, Action::kEmptyCOO, sizes); 702 Value ptr = genNewCall(rewriter, op, params); 703 Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType()); 704 Value perm = params[2]; 705 SmallVector<Value> lo; 706 SmallVector<Value> hi; 707 SmallVector<Value> st; 708 Value zero = constantIndex(rewriter, loc, 0); 709 Value one = constantIndex(rewriter, loc, 1); 710 auto indicesValues = genSplitSparseConstant(rewriter, loc, src); 711 bool isCOOConstant = indicesValues.hasValue(); 712 Value indices; 713 Value values; 714 if (isCOOConstant) { 715 indices = indicesValues->first; 716 values = indicesValues->second; 717 lo.push_back(zero); 718 hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, values, 0)); 719 st.push_back(one); 720 } else { 721 for (unsigned i = 0; i < rank; i++) { 722 lo.push_back(zero); 723 hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, src, i)); 724 st.push_back(one); 725 } 726 } 727 Type eltType = stp.getElementType(); 728 scf::buildLoopNest( 729 rewriter, op.getLoc(), lo, hi, st, {}, 730 [&](OpBuilder &builder, Location loc, ValueRange ivs, 731 ValueRange args) -> scf::ValueVector { 732 Value val; 733 if (isCOOConstant) 734 val = genIndexAndValueForSparse(rewriter, loc, indices, values, ind, 735 ivs, rank); 736 else 737 val = genIndexAndValueForDense(rewriter, loc, src, ind, ivs); 738 genAddEltCall(rewriter, op, eltType, ptr, val, ind, perm); 739 return {}; 740 }); 741 // Final call to construct sparse tensor storage. 742 params[6] = constantAction(rewriter, loc, Action::kFromCOO); 743 params[7] = ptr; 744 rewriter.replaceOp(op, genNewCall(rewriter, op, params)); 745 return success(); 746 } 747 }; 748 749 /// Sparse conversion rule for the release operator. 750 class SparseTensorReleaseConverter : public OpConversionPattern<ReleaseOp> { 751 public: 752 using OpConversionPattern::OpConversionPattern; 753 LogicalResult 754 matchAndRewrite(ReleaseOp op, OpAdaptor adaptor, 755 ConversionPatternRewriter &rewriter) const override { 756 StringRef name = "delSparseTensor"; 757 TypeRange none; 758 auto fn = getFunc(op, name, none, adaptor.getOperands()); 759 rewriter.create<CallOp>(op.getLoc(), none, fn, adaptor.getOperands()); 760 rewriter.eraseOp(op); 761 return success(); 762 } 763 }; 764 765 /// Sparse conversion rule for pointer accesses. 766 class SparseTensorToPointersConverter 767 : public OpConversionPattern<ToPointersOp> { 768 public: 769 using OpConversionPattern::OpConversionPattern; 770 LogicalResult 771 matchAndRewrite(ToPointersOp op, OpAdaptor adaptor, 772 ConversionPatternRewriter &rewriter) const override { 773 Type resType = op.getType(); 774 Type eltType = resType.cast<ShapedType>().getElementType(); 775 StringRef name; 776 if (eltType.isIndex()) 777 name = "sparsePointers"; 778 else if (eltType.isInteger(64)) 779 name = "sparsePointers64"; 780 else if (eltType.isInteger(32)) 781 name = "sparsePointers32"; 782 else if (eltType.isInteger(16)) 783 name = "sparsePointers16"; 784 else if (eltType.isInteger(8)) 785 name = "sparsePointers8"; 786 else 787 return failure(); 788 auto fn = getFunc(op, name, resType, adaptor.getOperands(), 789 /*emitCInterface=*/true); 790 rewriter.replaceOpWithNewOp<CallOp>(op, resType, fn, adaptor.getOperands()); 791 return success(); 792 } 793 }; 794 795 /// Sparse conversion rule for index accesses. 796 class SparseTensorToIndicesConverter : public OpConversionPattern<ToIndicesOp> { 797 public: 798 using OpConversionPattern::OpConversionPattern; 799 LogicalResult 800 matchAndRewrite(ToIndicesOp op, OpAdaptor adaptor, 801 ConversionPatternRewriter &rewriter) const override { 802 Type resType = op.getType(); 803 Type eltType = resType.cast<ShapedType>().getElementType(); 804 StringRef name; 805 if (eltType.isIndex()) 806 name = "sparseIndices"; 807 else if (eltType.isInteger(64)) 808 name = "sparseIndices64"; 809 else if (eltType.isInteger(32)) 810 name = "sparseIndices32"; 811 else if (eltType.isInteger(16)) 812 name = "sparseIndices16"; 813 else if (eltType.isInteger(8)) 814 name = "sparseIndices8"; 815 else 816 return failure(); 817 auto fn = getFunc(op, name, resType, adaptor.getOperands(), 818 /*emitCInterface=*/true); 819 rewriter.replaceOpWithNewOp<CallOp>(op, resType, fn, adaptor.getOperands()); 820 return success(); 821 } 822 }; 823 824 /// Sparse conversion rule for value accesses. 825 class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> { 826 public: 827 using OpConversionPattern::OpConversionPattern; 828 LogicalResult 829 matchAndRewrite(ToValuesOp op, OpAdaptor adaptor, 830 ConversionPatternRewriter &rewriter) const override { 831 Type resType = op.getType(); 832 Type eltType = resType.cast<ShapedType>().getElementType(); 833 StringRef name; 834 if (eltType.isF64()) 835 name = "sparseValuesF64"; 836 else if (eltType.isF32()) 837 name = "sparseValuesF32"; 838 else if (eltType.isInteger(64)) 839 name = "sparseValuesI64"; 840 else if (eltType.isInteger(32)) 841 name = "sparseValuesI32"; 842 else if (eltType.isInteger(16)) 843 name = "sparseValuesI16"; 844 else if (eltType.isInteger(8)) 845 name = "sparseValuesI8"; 846 else 847 return failure(); 848 auto fn = getFunc(op, name, resType, adaptor.getOperands(), 849 /*emitCInterface=*/true); 850 rewriter.replaceOpWithNewOp<CallOp>(op, resType, fn, adaptor.getOperands()); 851 return success(); 852 } 853 }; 854 855 /// Sparse conversion rule for tensor rematerialization. 856 class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> { 857 public: 858 using OpConversionPattern::OpConversionPattern; 859 LogicalResult 860 matchAndRewrite(LoadOp op, OpAdaptor adaptor, 861 ConversionPatternRewriter &rewriter) const override { 862 if (op.hasInserts()) { 863 // Finalize any pending insertions. 864 StringRef name = "endInsert"; 865 TypeRange noTp; 866 auto fn = getFunc(op, name, noTp, adaptor.getOperands()); 867 rewriter.create<CallOp>(op.getLoc(), noTp, fn, adaptor.getOperands()); 868 } 869 rewriter.replaceOp(op, adaptor.getOperands()); 870 return success(); 871 } 872 }; 873 874 /// Sparse conversion rule for inserting in lexicographic index order. 875 class SparseTensorLexInsertConverter : public OpConversionPattern<LexInsertOp> { 876 public: 877 using OpConversionPattern::OpConversionPattern; 878 LogicalResult 879 matchAndRewrite(LexInsertOp op, OpAdaptor adaptor, 880 ConversionPatternRewriter &rewriter) const override { 881 Type srcType = op.tensor().getType(); 882 Type eltType = srcType.cast<ShapedType>().getElementType(); 883 StringRef name; 884 if (eltType.isF64()) 885 name = "lexInsertF64"; 886 else if (eltType.isF32()) 887 name = "lexInsertF32"; 888 else if (eltType.isInteger(64)) 889 name = "lexInsertI64"; 890 else if (eltType.isInteger(32)) 891 name = "lexInsertI32"; 892 else if (eltType.isInteger(16)) 893 name = "lexInsertI16"; 894 else if (eltType.isInteger(8)) 895 name = "lexInsertI8"; 896 else 897 llvm_unreachable("Unknown element type"); 898 TypeRange noTp; 899 auto fn = 900 getFunc(op, name, noTp, adaptor.getOperands(), /*emitCInterface=*/true); 901 rewriter.replaceOpWithNewOp<CallOp>(op, noTp, fn, adaptor.getOperands()); 902 return success(); 903 } 904 }; 905 906 } // namespace 907 908 //===----------------------------------------------------------------------===// 909 // Public method for populating conversion rules. 910 //===----------------------------------------------------------------------===// 911 912 /// Populates the given patterns list with conversion rules required for 913 /// the sparsification of linear algebra operations. 914 void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter, 915 RewritePatternSet &patterns) { 916 patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter, 917 SparseCastConverter, SparseTensorNewConverter, 918 SparseTensorInitConverter, SparseTensorConvertConverter, 919 SparseTensorReleaseConverter, SparseTensorToPointersConverter, 920 SparseTensorToIndicesConverter, SparseTensorToValuesConverter, 921 SparseTensorLoadConverter, SparseTensorLexInsertConverter>( 922 typeConverter, patterns.getContext()); 923 } 924