1 //===- SparseTensorConversion.cpp - Sparse tensor primitives conversion ---===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Convert sparse tensor primitives to calls into a runtime support library. 10 // Note that this is a current implementation choice to keep the conversion 11 // simple. In principle, these primitives could also be converted to actual 12 // elaborate IR code that implements the primitives on the selected sparse 13 // tensor storage schemes. 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 18 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 19 #include "mlir/Dialect/Linalg/Utils/Utils.h" 20 #include "mlir/Dialect/MemRef/IR/MemRef.h" 21 #include "mlir/Dialect/SCF/SCF.h" 22 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 23 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" 24 #include "mlir/Dialect/StandardOps/IR/Ops.h" 25 #include "mlir/Dialect/Tensor/IR/Tensor.h" 26 #include "mlir/ExecutionEngine/SparseTensorUtils.h" 27 #include "mlir/Transforms/DialectConversion.h" 28 29 using namespace mlir; 30 using namespace mlir::sparse_tensor; 31 32 namespace { 33 34 //===----------------------------------------------------------------------===// 35 // Helper methods. 36 //===----------------------------------------------------------------------===// 37 38 /// Generates a constant zero of the given type. 39 inline static Value constantZero(ConversionPatternRewriter &rewriter, 40 Location loc, Type t) { 41 return rewriter.create<arith::ConstantOp>(loc, t, rewriter.getZeroAttr(t)); 42 } 43 44 /// Generates a constant of `index` type. 
45 inline static Value constantIndex(ConversionPatternRewriter &rewriter, 46 Location loc, int64_t i) { 47 return rewriter.create<arith::ConstantIndexOp>(loc, i); 48 } 49 50 /// Generates a constant of `i32` type. 51 inline static Value constantI32(ConversionPatternRewriter &rewriter, 52 Location loc, int32_t i) { 53 return rewriter.create<arith::ConstantIntOp>(loc, i, 32); 54 } 55 56 /// Generates a constant of `i8` type. 57 inline static Value constantI8(ConversionPatternRewriter &rewriter, 58 Location loc, int8_t i) { 59 return rewriter.create<arith::ConstantIntOp>(loc, i, 8); 60 } 61 62 /// Generates a constant of the given `Action`. 63 static Value constantAction(ConversionPatternRewriter &rewriter, Location loc, 64 Action action) { 65 return constantI32(rewriter, loc, static_cast<uint32_t>(action)); 66 } 67 68 /// Generates a constant of the internal type encoding for overhead storage. 69 static Value constantOverheadTypeEncoding(ConversionPatternRewriter &rewriter, 70 Location loc, unsigned width) { 71 OverheadType sec; 72 switch (width) { 73 default: 74 sec = OverheadType::kU64; 75 break; 76 case 32: 77 sec = OverheadType::kU32; 78 break; 79 case 16: 80 sec = OverheadType::kU16; 81 break; 82 case 8: 83 sec = OverheadType::kU8; 84 break; 85 } 86 return constantI32(rewriter, loc, static_cast<uint32_t>(sec)); 87 } 88 89 /// Generates a constant of the internal type encoding for pointer 90 /// overhead storage. 91 static Value constantPointerTypeEncoding(ConversionPatternRewriter &rewriter, 92 Location loc, 93 SparseTensorEncodingAttr &enc) { 94 return constantOverheadTypeEncoding(rewriter, loc, enc.getPointerBitWidth()); 95 } 96 97 /// Generates a constant of the internal type encoding for index overhead 98 /// storage. 
99 static Value constantIndexTypeEncoding(ConversionPatternRewriter &rewriter, 100 Location loc, 101 SparseTensorEncodingAttr &enc) { 102 return constantOverheadTypeEncoding(rewriter, loc, enc.getIndexBitWidth()); 103 } 104 105 /// Generates a constant of the internal type encoding for primary storage. 106 static Value constantPrimaryTypeEncoding(ConversionPatternRewriter &rewriter, 107 Location loc, Type tp) { 108 PrimaryType primary; 109 if (tp.isF64()) 110 primary = PrimaryType::kF64; 111 else if (tp.isF32()) 112 primary = PrimaryType::kF32; 113 else if (tp.isInteger(64)) 114 primary = PrimaryType::kI64; 115 else if (tp.isInteger(32)) 116 primary = PrimaryType::kI32; 117 else if (tp.isInteger(16)) 118 primary = PrimaryType::kI16; 119 else if (tp.isInteger(8)) 120 primary = PrimaryType::kI8; 121 else 122 llvm_unreachable("Unknown element type"); 123 return constantI32(rewriter, loc, static_cast<uint32_t>(primary)); 124 } 125 126 /// Generates a constant of the internal dimension level type encoding. 127 static Value 128 constantDimLevelTypeEncoding(ConversionPatternRewriter &rewriter, Location loc, 129 SparseTensorEncodingAttr::DimLevelType dlt) { 130 DimLevelType dlt2; 131 switch (dlt) { 132 case SparseTensorEncodingAttr::DimLevelType::Dense: 133 dlt2 = DimLevelType::kDense; 134 break; 135 case SparseTensorEncodingAttr::DimLevelType::Compressed: 136 dlt2 = DimLevelType::kCompressed; 137 break; 138 case SparseTensorEncodingAttr::DimLevelType::Singleton: 139 dlt2 = DimLevelType::kSingleton; 140 break; 141 } 142 return constantI8(rewriter, loc, static_cast<uint8_t>(dlt2)); 143 } 144 145 /// Returns the equivalent of `void*` for opaque arguments to the 146 /// execution engine. 147 static Type getOpaquePointerType(PatternRewriter &rewriter) { 148 return LLVM::LLVMPointerType::get(rewriter.getI8Type()); 149 } 150 151 /// Returns a function reference (first hit also inserts into module). 
Sets 152 /// the "_emit_c_interface" on the function declaration when requested, 153 /// so that LLVM lowering generates a wrapper function that takes care 154 /// of ABI complications with passing in and returning MemRefs to C functions. 155 static FlatSymbolRefAttr getFunc(Operation *op, StringRef name, 156 TypeRange resultType, ValueRange operands, 157 bool emitCInterface) { 158 MLIRContext *context = op->getContext(); 159 auto module = op->getParentOfType<ModuleOp>(); 160 auto result = SymbolRefAttr::get(context, name); 161 auto func = module.lookupSymbol<FuncOp>(result.getAttr()); 162 if (!func) { 163 OpBuilder moduleBuilder(module.getBodyRegion()); 164 func = moduleBuilder.create<FuncOp>( 165 op->getLoc(), name, 166 FunctionType::get(context, operands.getTypes(), resultType)); 167 func.setPrivate(); 168 if (emitCInterface) 169 func->setAttr("llvm.emit_c_interface", UnitAttr::get(context)); 170 } 171 return result; 172 } 173 174 /// Creates a `CallOp` to the function reference returned by `getFunc()`. 175 static CallOp createFuncCall(OpBuilder &builder, Operation *op, StringRef name, 176 TypeRange resultType, ValueRange operands, 177 bool emitCInterface = false) { 178 auto fn = getFunc(op, name, resultType, operands, emitCInterface); 179 return builder.create<CallOp>(op->getLoc(), resultType, fn, operands); 180 } 181 182 /// Replaces the `op` with a `CallOp` to the function reference returned 183 /// by `getFunc()`. 184 static CallOp replaceOpWithFuncCall(PatternRewriter &rewriter, Operation *op, 185 StringRef name, TypeRange resultType, 186 ValueRange operands, 187 bool emitCInterface = false) { 188 auto fn = getFunc(op, name, resultType, operands, emitCInterface); 189 return rewriter.replaceOpWithNewOp<CallOp>(op, resultType, fn, operands); 190 } 191 192 /// Generates dimension size call. 
193 static Value genDimSizeCall(ConversionPatternRewriter &rewriter, Operation *op, 194 SparseTensorEncodingAttr &enc, Value src, 195 int64_t idx) { 196 // Permute the index according to an optional dimension ordering. 197 if (AffineMap p = enc.getDimOrdering()) 198 idx = p.getPermutedPosition(idx); 199 // Generate the call. 200 StringRef name = "sparseDimSize"; 201 SmallVector<Value, 2> params{src, constantIndex(rewriter, op->getLoc(), idx)}; 202 Type iTp = rewriter.getIndexType(); 203 return createFuncCall(rewriter, op, name, iTp, params).getResult(0); 204 } 205 206 /// Generates a call into the "swiss army knife" method of the sparse runtime 207 /// support library for materializing sparse tensors into the computation. 208 static Value genNewCall(ConversionPatternRewriter &rewriter, Operation *op, 209 ArrayRef<Value> params) { 210 StringRef name = "newSparseTensor"; 211 Type pTp = getOpaquePointerType(rewriter); 212 auto call = createFuncCall(rewriter, op, name, pTp, params, 213 /*emitCInterface=*/true); 214 return call.getResult(0); 215 } 216 217 /// Populates given sizes array from type. 218 static void sizesFromType(ConversionPatternRewriter &rewriter, 219 SmallVector<Value, 4> &sizes, Location loc, 220 ShapedType stp) { 221 auto shape = stp.getShape(); 222 for (unsigned i = 0, rank = stp.getRank(); i < rank; i++) { 223 uint64_t s = shape[i] == ShapedType::kDynamicSize ? 0 : shape[i]; 224 sizes.push_back(constantIndex(rewriter, loc, s)); 225 } 226 } 227 228 /// Populates given sizes array from source. 229 static void sizesFromSrc(ConversionPatternRewriter &rewriter, 230 SmallVector<Value, 4> &sizes, Location loc, 231 Value src) { 232 unsigned rank = src.getType().cast<ShapedType>().getRank(); 233 for (unsigned i = 0; i < rank; i++) 234 sizes.push_back(linalg::createOrFoldDimOp(rewriter, loc, src, i)); 235 } 236 237 /// Populates given sizes array from type (for static sizes) and from 238 /// an already converted into opague pointer source (for dynamic sizes). 
239 static void sizesFromPtr(ConversionPatternRewriter &rewriter, 240 SmallVector<Value, 4> &sizes, Operation *op, 241 SparseTensorEncodingAttr &enc, ShapedType stp, 242 Value src) { 243 Location loc = op->getLoc(); 244 auto shape = stp.getShape(); 245 for (unsigned i = 0, rank = stp.getRank(); i < rank; i++) 246 if (shape[i] == ShapedType::kDynamicSize) 247 sizes.push_back(genDimSizeCall(rewriter, op, enc, src, i)); 248 else 249 sizes.push_back(constantIndex(rewriter, loc, shape[i])); 250 } 251 252 /// Generates an uninitialized temporary buffer of the given size and 253 /// type, but returns it as type `memref<? x $tp>` (rather than as type 254 /// `memref<$sz x $tp>`). 255 static Value genAlloca(ConversionPatternRewriter &rewriter, Location loc, 256 unsigned sz, Type tp) { 257 auto memTp = MemRefType::get({ShapedType::kDynamicSize}, tp); 258 Value a = constantIndex(rewriter, loc, sz); 259 return rewriter.create<memref::AllocaOp>(loc, memTp, ValueRange{a}); 260 } 261 262 /// Generates an uninitialized temporary buffer with room for one value 263 /// of the given type, and returns the `memref<$tp>`. 264 static Value genAllocaScalar(ConversionPatternRewriter &rewriter, Location loc, 265 Type tp) { 266 return rewriter.create<memref::AllocaOp>(loc, MemRefType::get({}, tp)); 267 } 268 269 /// Generates a temporary buffer of the given type and given contents. 270 static Value genBuffer(ConversionPatternRewriter &rewriter, Location loc, 271 ArrayRef<Value> values) { 272 unsigned sz = values.size(); 273 assert(sz >= 1); 274 Value buffer = genAlloca(rewriter, loc, sz, values[0].getType()); 275 for (unsigned i = 0; i < sz; i++) { 276 Value idx = constantIndex(rewriter, loc, i); 277 rewriter.create<memref::StoreOp>(loc, values[i], buffer, idx); 278 } 279 return buffer; 280 } 281 282 /// Populates parameters required to call the "swiss army knife" method of the 283 /// sparse runtime support library for materializing sparse tensors into the 284 /// computation. 
285 static void newParams(ConversionPatternRewriter &rewriter, 286 SmallVector<Value, 8> ¶ms, Operation *op, 287 SparseTensorEncodingAttr &enc, Action action, 288 ValueRange szs, Value ptr = Value()) { 289 Location loc = op->getLoc(); 290 ArrayRef<SparseTensorEncodingAttr::DimLevelType> dlt = enc.getDimLevelType(); 291 unsigned sz = dlt.size(); 292 // Sparsity annotations. 293 SmallVector<Value, 4> attrs; 294 for (unsigned i = 0; i < sz; i++) 295 attrs.push_back(constantDimLevelTypeEncoding(rewriter, loc, dlt[i])); 296 params.push_back(genBuffer(rewriter, loc, attrs)); 297 // Dimension sizes array of the enveloping tensor. Useful for either 298 // verification of external data, or for construction of internal data. 299 SmallVector<Value, 4> sizes; 300 for (Value s : szs) 301 sizes.push_back(s); 302 params.push_back(genBuffer(rewriter, loc, sizes)); 303 // Dimension order permutation array. This is the "identity" permutation by 304 // default, or otherwise the "reverse" permutation of a given ordering, so 305 // that indices can be mapped quickly to the right position. 306 SmallVector<Value, 4> rev(sz); 307 if (AffineMap p = enc.getDimOrdering()) { 308 for (unsigned i = 0; i < sz; i++) 309 rev[p.getDimPosition(i)] = constantIndex(rewriter, loc, i); 310 } else { 311 for (unsigned i = 0; i < sz; i++) 312 rev[i] = constantIndex(rewriter, loc, i); 313 } 314 params.push_back(genBuffer(rewriter, loc, rev)); 315 // Secondary and primary types encoding. 316 Type elemTp = op->getResult(0).getType().cast<ShapedType>().getElementType(); 317 params.push_back(constantPointerTypeEncoding(rewriter, loc, enc)); 318 params.push_back(constantIndexTypeEncoding(rewriter, loc, enc)); 319 params.push_back(constantPrimaryTypeEncoding(rewriter, loc, elemTp)); 320 // User action. 321 params.push_back(constantAction(rewriter, loc, action)); 322 // Payload pointer. 
323 if (!ptr) 324 ptr = rewriter.create<LLVM::NullOp>(loc, getOpaquePointerType(rewriter)); 325 params.push_back(ptr); 326 } 327 328 /// Generates the comparison `v != 0` where `v` is of numeric type `t`. 329 /// For floating types, we use the "unordered" comparator (i.e., returns 330 /// true if `v` is NaN). 331 static Value genIsNonzero(ConversionPatternRewriter &rewriter, Location loc, 332 Value v) { 333 Type t = v.getType(); 334 Value zero = constantZero(rewriter, loc, t); 335 if (t.isa<FloatType>()) 336 return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE, v, 337 zero); 338 if (t.isIntOrIndex()) 339 return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, v, 340 zero); 341 llvm_unreachable("Unknown element type"); 342 } 343 344 /// Generates the code to read the value from tensor[ivs], and conditionally 345 /// stores the indices ivs to the memory in ind. The generated code looks like 346 /// the following and the insertion point after this routine is inside the 347 /// if-then branch behind the assignment to ind. This is to ensure that the 348 /// addEltX call generated after is inside the if-then branch. 349 /// if (tensor[ivs]!=0) { 350 /// ind = ivs 351 static Value genIndexAndValueForDense(ConversionPatternRewriter &rewriter, 352 Location loc, Value tensor, Value ind, 353 ValueRange ivs) { 354 Value val = rewriter.create<tensor::ExtractOp>(loc, tensor, ivs); 355 Value cond = genIsNonzero(rewriter, loc, val); 356 scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ false); 357 rewriter.setInsertionPointToStart(&ifOp.thenRegion().front()); 358 unsigned i = 0; 359 for (auto iv : ivs) { 360 Value idx = constantIndex(rewriter, loc, i++); 361 rewriter.create<memref::StoreOp>(loc, iv, ind, idx); 362 } 363 return val; 364 } 365 366 /// Generates a call that adds one element to a coordinate scheme. 
367 /// In particular, this generates code like the following: 368 /// val = a[i1,..,ik]; 369 /// if val != 0 370 /// t->add(val, [i1,..,ik], [p1,..,pk]); 371 static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op, 372 Type eltType, Value ptr, Value val, Value ind, 373 Value perm) { 374 StringRef name; 375 if (eltType.isF64()) 376 name = "addEltF64"; 377 else if (eltType.isF32()) 378 name = "addEltF32"; 379 else if (eltType.isInteger(64)) 380 name = "addEltI64"; 381 else if (eltType.isInteger(32)) 382 name = "addEltI32"; 383 else if (eltType.isInteger(16)) 384 name = "addEltI16"; 385 else if (eltType.isInteger(8)) 386 name = "addEltI8"; 387 else 388 llvm_unreachable("Unknown element type"); 389 SmallVector<Value, 4> params{ptr, val, ind, perm}; 390 Type pTp = getOpaquePointerType(rewriter); 391 createFuncCall(rewriter, op, name, pTp, params, /*emitCInterface=*/true); 392 } 393 394 /// Generates a call to `iter->getNext()`. If there is a next element, 395 /// then it is copied into the out-parameters `ind` and `elemPtr`, 396 /// and the return value is true. If there isn't a next element, then 397 /// the memory for `iter` is freed and the return value is false. 
398 static Value genGetNextCall(ConversionPatternRewriter &rewriter, Operation *op, 399 Value iter, Value ind, Value elemPtr) { 400 Type elemTp = elemPtr.getType().cast<ShapedType>().getElementType(); 401 StringRef name; 402 if (elemTp.isF64()) 403 name = "getNextF64"; 404 else if (elemTp.isF32()) 405 name = "getNextF32"; 406 else if (elemTp.isInteger(64)) 407 name = "getNextI64"; 408 else if (elemTp.isInteger(32)) 409 name = "getNextI32"; 410 else if (elemTp.isInteger(16)) 411 name = "getNextI16"; 412 else if (elemTp.isInteger(8)) 413 name = "getNextI8"; 414 else 415 llvm_unreachable("Unknown element type"); 416 SmallVector<Value, 3> params{iter, ind, elemPtr}; 417 Type i1 = rewriter.getI1Type(); 418 auto call = createFuncCall(rewriter, op, name, i1, params, 419 /*emitCInterface=*/true); 420 return call.getResult(0); 421 } 422 423 /// If the tensor is a sparse constant, generates and returns the pair of 424 /// the constants for the indices and the values. 425 static Optional<std::pair<Value, Value>> 426 genSplitSparseConstant(ConversionPatternRewriter &rewriter, Location loc, 427 Value tensor) { 428 if (auto constOp = tensor.getDefiningOp<arith::ConstantOp>()) { 429 if (auto attr = constOp.getValue().dyn_cast<SparseElementsAttr>()) { 430 DenseElementsAttr indicesAttr = attr.getIndices(); 431 Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr); 432 DenseElementsAttr valuesAttr = attr.getValues(); 433 Value values = rewriter.create<arith::ConstantOp>(loc, valuesAttr); 434 return std::make_pair(indices, values); 435 } 436 } 437 return {}; 438 } 439 440 /// Generates the code to copy the index at indices[ivs] to ind, and return 441 /// the value at value[ivs]. 
442 static Value genIndexAndValueForSparse(ConversionPatternRewriter &rewriter, 443 Location loc, Value indices, 444 Value values, Value ind, ValueRange ivs, 445 unsigned rank) { 446 for (unsigned i = 0; i < rank; i++) { 447 Value idx = constantIndex(rewriter, loc, i); 448 Value val = rewriter.create<tensor::ExtractOp>(loc, indices, 449 ValueRange{ivs[0], idx}); 450 val = 451 rewriter.create<arith::IndexCastOp>(loc, val, rewriter.getIndexType()); 452 rewriter.create<memref::StoreOp>(loc, val, ind, idx); 453 } 454 return rewriter.create<tensor::ExtractOp>(loc, values, ivs[0]); 455 } 456 457 /// Generates code to allocate a tensor of the given type, and zero 458 /// initialize it. If the tensor type has any dynamic sizes, then the 459 /// `sizes` parameter should be as filled by sizesFromPtr(); that way 460 /// we can reuse the genDimSizeCall() results generated by sizesFromPtr(). 461 static Value allocDenseTensor(ConversionPatternRewriter &rewriter, Location loc, 462 RankedTensorType tensorTp, ValueRange sizes) { 463 Type elemTp = tensorTp.getElementType(); 464 auto shape = tensorTp.getShape(); 465 auto memTp = MemRefType::get(shape, elemTp); 466 SmallVector<Value> dynamicSizes; 467 for (unsigned i = 0, rank = tensorTp.getRank(); i < rank; i++) { 468 if (shape[i] == ShapedType::kDynamicSize) 469 dynamicSizes.push_back(sizes[i]); 470 } 471 Value mem = rewriter.create<memref::AllocOp>(loc, memTp, dynamicSizes); 472 Value zero = constantZero(rewriter, loc, elemTp); 473 rewriter.create<linalg::FillOp>(loc, zero, mem); 474 return mem; 475 } 476 477 /// Inserts the element returned by genGetNextCall(_, ind, elemPtr) into 478 /// the tensor created by allocDenseTensor(). The `rank` is the rank 479 /// of the `tensor` and the length of `ind`. 
480 static void insertScalarIntoDenseTensor(ConversionPatternRewriter &rewriter, 481 Location loc, Value elemPtr, 482 Value tensor, unsigned rank, 483 Value ind) { 484 SmallVector<Value, 4> ivs; 485 ivs.reserve(rank); 486 for (unsigned i = 0; i < rank; i++) { 487 Value idx = constantIndex(rewriter, loc, i); 488 ivs.push_back(rewriter.create<memref::LoadOp>(loc, ind, idx)); 489 } 490 Value elemV = rewriter.create<memref::LoadOp>(loc, elemPtr); 491 rewriter.create<memref::StoreOp>(loc, elemV, tensor, ivs); 492 } 493 494 //===----------------------------------------------------------------------===// 495 // Conversion rules. 496 //===----------------------------------------------------------------------===// 497 498 /// Sparse conversion rule for returns. 499 class SparseReturnConverter : public OpConversionPattern<ReturnOp> { 500 public: 501 using OpConversionPattern::OpConversionPattern; 502 LogicalResult 503 matchAndRewrite(ReturnOp op, OpAdaptor adaptor, 504 ConversionPatternRewriter &rewriter) const override { 505 rewriter.replaceOpWithNewOp<ReturnOp>(op, adaptor.getOperands()); 506 return success(); 507 } 508 }; 509 510 /// Sparse conversion rule for dimension accesses. 511 class SparseTensorToDimSizeConverter 512 : public OpConversionPattern<tensor::DimOp> { 513 public: 514 using OpConversionPattern::OpConversionPattern; 515 LogicalResult 516 matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor, 517 ConversionPatternRewriter &rewriter) const override { 518 // Only rewrite annotated DimOp with constant index. 519 auto enc = getSparseTensorEncoding(op.source().getType()); 520 if (!enc) 521 return failure(); 522 Optional<int64_t> index = op.getConstantIndex(); 523 if (!index.hasValue()) 524 return failure(); 525 // Generate the call. 
526 Value src = adaptor.getOperands()[0]; 527 int64_t idx = index.getValue(); 528 rewriter.replaceOp(op, genDimSizeCall(rewriter, op, enc, src, idx)); 529 return success(); 530 } 531 }; 532 533 /// Sparse conversion rule for trivial tensor casts. 534 class SparseCastConverter : public OpConversionPattern<tensor::CastOp> { 535 using OpConversionPattern::OpConversionPattern; 536 LogicalResult 537 matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor, 538 ConversionPatternRewriter &rewriter) const override { 539 // Only rewrite identically annotated source/dest. 540 auto encDst = getSparseTensorEncoding(op.getType()); 541 auto encSrc = getSparseTensorEncoding(op.source().getType()); 542 if (!encDst || encDst != encSrc) 543 return failure(); 544 rewriter.replaceOp(op, adaptor.getOperands()); 545 return success(); 546 } 547 }; 548 549 /// Sparse conversion rule for the new operator. 550 class SparseTensorNewConverter : public OpConversionPattern<NewOp> { 551 using OpConversionPattern::OpConversionPattern; 552 LogicalResult 553 matchAndRewrite(NewOp op, OpAdaptor adaptor, 554 ConversionPatternRewriter &rewriter) const override { 555 Type resType = op.getType(); 556 auto enc = getSparseTensorEncoding(resType); 557 if (!enc) 558 return failure(); 559 // Generate the call to construct tensor from ptr. The sizes are 560 // inferred from the result type of the new operator. 561 SmallVector<Value, 4> sizes; 562 SmallVector<Value, 8> params; 563 sizesFromType(rewriter, sizes, op.getLoc(), resType.cast<ShapedType>()); 564 Value ptr = adaptor.getOperands()[0]; 565 newParams(rewriter, params, op, enc, Action::kFromFile, sizes, ptr); 566 rewriter.replaceOp(op, genNewCall(rewriter, op, params)); 567 return success(); 568 } 569 }; 570 571 /// Sparse conversion rule for the init operator. 
class SparseTensorInitConverter : public OpConversionPattern<InitOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(InitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only results with a sparse encoding are rewritten.
    Type resType = op.getType();
    auto enc = getSparseTensorEncoding(resType);
    if (!enc)
      return failure();
    // Generate the call to construct empty tensor. The sizes are
    // explicitly defined by the arguments to the init operator.
    SmallVector<Value, 8> params;
    newParams(rewriter, params, op, enc, Action::kEmpty, adaptor.getOperands());
    rewriter.replaceOp(op, genNewCall(rewriter, op, params));
    return success();
  }
};

/// Sparse conversion rule for the convert operator. Dispatches on which side
/// carries a sparse encoding:
///   sparse => sparse : identical encodings forward the operand; otherwise
///                      convert through a COO intermediate.
///   sparse => dense  : iterate the source and scatter into a zeroed buffer.
///   dense  => dense  : not handled here (failure()).
///   dense / sparse constant => sparse : fill a COO, then build the result.
///
/// All paths reuse the parameter layout produced by newParams():
///   [0] dim level types  [1] dim sizes   [2] reverse permutation
///   [3] pointer type     [4] index type  [5] primary type
///   [6] action           [7] payload pointer
class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ConvertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    Type resType = op.getType();
    Type srcType = op.source().getType();
    auto encDst = getSparseTensorEncoding(resType);
    auto encSrc = getSparseTensorEncoding(srcType);
    // Converted source operand (an opaque pointer when `src` is sparse).
    Value src = adaptor.getOperands()[0];
    if (encDst && encSrc) {
      // This is a sparse => sparse conversion, which is handled as follows:
      //   t = src->toCOO();         ; src to COO in dst order
      //   dst = newSparseTensor(t)
      // Using the coordinate scheme as an intermediate does not always
      // yield the fastest conversion but avoids the need for a full
      // O(N^2) conversion matrix.
      if (encDst == encSrc) {
        rewriter.replaceOp(op, adaptor.getOperands()); // hidden nop cast
        return success();
      }
      SmallVector<Value, 4> sizes;
      SmallVector<Value, 8> params;
      sizesFromPtr(rewriter, sizes, op, encSrc, srcType.cast<ShapedType>(),
                   src);
      // Set up encoding with right mix of src and dst so that the two
      // method calls can share most parameters, while still providing
      // the correct sparsity information to either of them.
      auto enc = SparseTensorEncodingAttr::get(
          op->getContext(), encDst.getDimLevelType(), encDst.getDimOrdering(),
          encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
      newParams(rewriter, params, op, enc, Action::kToCOO, sizes, src);
      Value coo = genNewCall(rewriter, op, params);
      // Second call: switch the overhead types ([3], [4]) to the
      // destination's, the action ([6]) to kFromCOO, and the payload ([7])
      // to the COO just produced.
      params[3] = constantPointerTypeEncoding(rewriter, loc, encDst);
      params[4] = constantIndexTypeEncoding(rewriter, loc, encDst);
      params[6] = constantAction(rewriter, loc, Action::kFromCOO);
      params[7] = coo;
      rewriter.replaceOp(op, genNewCall(rewriter, op, params));
      return success();
    }
    if (!encDst && encSrc) {
      // This is sparse => dense conversion, which is handled as follows:
      //   dst = new Tensor(0);
      //   iter = src->toCOO();
      //   iter->startIterator();
      //   while (elem = iter->getNext()) {
      //     dst[elem.indices] = elem.value;
      //   }
      RankedTensorType dstTensorTp = resType.cast<RankedTensorType>();
      RankedTensorType srcTensorTp = srcType.cast<RankedTensorType>();
      unsigned rank = dstTensorTp.getRank();
      Type elemTp = dstTensorTp.getElementType();
      // Fabricate a no-permutation encoding for newParams().
      // The pointer/index types must be those of `src`.
      // The dimLevelTypes aren't actually used by Action::kToIterator.
      encDst = SparseTensorEncodingAttr::get(
          op->getContext(),
          SmallVector<SparseTensorEncodingAttr::DimLevelType>(
              rank, SparseTensorEncodingAttr::DimLevelType::Dense),
          AffineMap(), encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
      SmallVector<Value, 4> sizes;
      SmallVector<Value, 8> params;
      sizesFromPtr(rewriter, sizes, op, encSrc, srcTensorTp, src);
      newParams(rewriter, params, op, encDst, Action::kToIterator, sizes, src);
      Value iter = genNewCall(rewriter, op, params);
      // Out-parameters that getNext() fills: coordinates and element value.
      Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
      Value elemPtr = genAllocaScalar(rewriter, loc, elemTp);
      // `sizes` is reused so dynamic dimensions need no extra runtime calls.
      // NOTE(review): `dst` is memref.alloc'ed and not deallocated here;
      // presumably later bufferization/deallocation handles it -- confirm.
      Value dst = allocDenseTensor(rewriter, loc, dstTensorTp, sizes);
      // Emit a loop-carried-value-free scf.while:
      //   before: cond = getNext(iter, ind, elemPtr); scf.condition(cond)
      //   after : dst[ind] = *elemPtr; scf.yield
      SmallVector<Value> noArgs;
      SmallVector<Type> noTypes;
      auto whileOp = rewriter.create<scf::WhileOp>(loc, noTypes, noArgs);
      Block *before = rewriter.createBlock(&whileOp.before(), {}, noTypes);
      rewriter.setInsertionPointToEnd(before);
      Value cond = genGetNextCall(rewriter, op, iter, ind, elemPtr);
      rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments());
      Block *after = rewriter.createBlock(&whileOp.after(), {}, noTypes);
      rewriter.setInsertionPointToStart(after);
      insertScalarIntoDenseTensor(rewriter, loc, elemPtr, dst, rank, ind);
      rewriter.create<scf::YieldOp>(loc);
      // Continue after the loop and expose the filled buffer as a tensor.
      rewriter.setInsertionPointAfter(whileOp);
      rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, resType, dst);
      return success();
    }
    if (!encDst && !encSrc) {
      // dense => dense
      return failure();
    }
    // This is a dense => sparse conversion or a sparse constant in COO =>
    // sparse conversion, which is handled as follows:
    //   t = newSparseCOO()
    //   ...code to fill the COO tensor t...
    //   s = newSparseTensor(t)
    //
    // To fill the COO tensor from a dense tensor:
    //   for i1 in dim1
    //    ..
    //     for ik in dimk
    //       val = a[i1,..,ik]
    //       if val != 0
    //         t->add(val, [i1,..,ik], [p1,..,pk])
    //
    // To fill the COO tensor from a sparse constant in COO format:
    //   for i in range(NNZ)
    //     val = values[i]
    //     [i1,..,ik] = indices[i]
    //     t->add(val, [i1,..,ik], [p1,..,pk])
    //
    // Note that the dense tensor traversal code is actually implemented
    // using MLIR IR to avoid having to expose too much low-level
    // memref traversal details to the runtime support library.
    // Also note that the code below only generates the "new" ops and
    // the loop-nest per se; whereas the entire body of the innermost
    // loop is generated by genIndexAndValueForDense() or
    // genIndexAndValueForSparse() followed by genAddEltCall().
    ShapedType stp = resType.cast<ShapedType>();
    unsigned rank = stp.getRank();
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 8> params;
    sizesFromSrc(rewriter, sizes, loc, src);
    newParams(rewriter, params, op, encDst, Action::kEmptyCOO, sizes);
    Value ptr = genNewCall(rewriter, op, params);
    Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
    // params[2] is the reverse-permutation buffer built by newParams().
    Value perm = params[2];
    SmallVector<Value> lo;
    SmallVector<Value> hi;
    SmallVector<Value> st;
    Value zero = constantIndex(rewriter, loc, 0);
    Value one = constantIndex(rewriter, loc, 1);
    auto indicesValues = genSplitSparseConstant(rewriter, loc, src);
    bool isCOOConstant = indicesValues.hasValue();
    Value indices;
    Value values;
    if (isCOOConstant) {
      // A single loop over the constant's entries: [0, dim(values, 0)).
      indices = indicesValues->first;
      values = indicesValues->second;
      lo.push_back(zero);
      hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, values, 0));
      st.push_back(one);
    } else {
      // One loop per dimension of the dense source: [0, dim(src, i)).
      for (unsigned i = 0; i < rank; i++) {
        lo.push_back(zero);
        hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, src, i));
        st.push_back(one);
      }
    }
    Type eltType = stp.getElementType();
    // The body builder uses `rewriter` directly (not `builder`), so the
    // insertion point set inside genIndexAndValueForDense() also applies
    // to the genAddEltCall() that follows it.
    scf::buildLoopNest(
        rewriter, op.getLoc(), lo, hi, st, {},
        [&](OpBuilder &builder, Location loc, ValueRange ivs,
            ValueRange args) -> scf::ValueVector {
          Value val;
          if (isCOOConstant)
            val = genIndexAndValueForSparse(rewriter, loc, indices, values, ind,
                                            ivs, rank);
          else
            val = genIndexAndValueForDense(rewriter, loc, src, ind, ivs);
          genAddEltCall(rewriter, op, eltType, ptr, val, ind, perm);
          return {};
        });
    // Final call to construct sparse tensor storage.
    // Reuse the kEmptyCOO parameters with action [6] = kFromCOO and
    // payload [7] = the filled COO.
    params[6] = constantAction(rewriter, loc, Action::kFromCOO);
    params[7] = ptr;
    rewriter.replaceOp(op, genNewCall(rewriter, op, params));
    return success();
  }
};

/// Sparse conversion rule for the release operator: emits a call to the
/// runtime's `delSparseTensor` with the converted operand and erases the op.
class SparseTensorReleaseConverter : public OpConversionPattern<ReleaseOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReleaseOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    StringRef name = "delSparseTensor";
    TypeRange noTp;
    createFuncCall(rewriter, op, name, noTp, adaptor.getOperands());
    rewriter.eraseOp(op);
    return success();
  }
};

/// Sparse conversion rule for pointer accesses.
774 class SparseTensorToPointersConverter 775 : public OpConversionPattern<ToPointersOp> { 776 public: 777 using OpConversionPattern::OpConversionPattern; 778 LogicalResult 779 matchAndRewrite(ToPointersOp op, OpAdaptor adaptor, 780 ConversionPatternRewriter &rewriter) const override { 781 Type resType = op.getType(); 782 Type eltType = resType.cast<ShapedType>().getElementType(); 783 StringRef name; 784 if (eltType.isIndex()) 785 name = "sparsePointers"; 786 else if (eltType.isInteger(64)) 787 name = "sparsePointers64"; 788 else if (eltType.isInteger(32)) 789 name = "sparsePointers32"; 790 else if (eltType.isInteger(16)) 791 name = "sparsePointers16"; 792 else if (eltType.isInteger(8)) 793 name = "sparsePointers8"; 794 else 795 return failure(); 796 replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(), 797 /*emitCInterface=*/true); 798 return success(); 799 } 800 }; 801 802 /// Sparse conversion rule for index accesses. 803 class SparseTensorToIndicesConverter : public OpConversionPattern<ToIndicesOp> { 804 public: 805 using OpConversionPattern::OpConversionPattern; 806 LogicalResult 807 matchAndRewrite(ToIndicesOp op, OpAdaptor adaptor, 808 ConversionPatternRewriter &rewriter) const override { 809 Type resType = op.getType(); 810 Type eltType = resType.cast<ShapedType>().getElementType(); 811 StringRef name; 812 if (eltType.isIndex()) 813 name = "sparseIndices"; 814 else if (eltType.isInteger(64)) 815 name = "sparseIndices64"; 816 else if (eltType.isInteger(32)) 817 name = "sparseIndices32"; 818 else if (eltType.isInteger(16)) 819 name = "sparseIndices16"; 820 else if (eltType.isInteger(8)) 821 name = "sparseIndices8"; 822 else 823 return failure(); 824 replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(), 825 /*emitCInterface=*/true); 826 return success(); 827 } 828 }; 829 830 /// Sparse conversion rule for value accesses. 
831 class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> { 832 public: 833 using OpConversionPattern::OpConversionPattern; 834 LogicalResult 835 matchAndRewrite(ToValuesOp op, OpAdaptor adaptor, 836 ConversionPatternRewriter &rewriter) const override { 837 Type resType = op.getType(); 838 Type eltType = resType.cast<ShapedType>().getElementType(); 839 StringRef name; 840 if (eltType.isF64()) 841 name = "sparseValuesF64"; 842 else if (eltType.isF32()) 843 name = "sparseValuesF32"; 844 else if (eltType.isInteger(64)) 845 name = "sparseValuesI64"; 846 else if (eltType.isInteger(32)) 847 name = "sparseValuesI32"; 848 else if (eltType.isInteger(16)) 849 name = "sparseValuesI16"; 850 else if (eltType.isInteger(8)) 851 name = "sparseValuesI8"; 852 else 853 return failure(); 854 replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(), 855 /*emitCInterface=*/true); 856 return success(); 857 } 858 }; 859 860 /// Sparse conversion rule for tensor rematerialization. 861 class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> { 862 public: 863 using OpConversionPattern::OpConversionPattern; 864 LogicalResult 865 matchAndRewrite(LoadOp op, OpAdaptor adaptor, 866 ConversionPatternRewriter &rewriter) const override { 867 if (op.hasInserts()) { 868 // Finalize any pending insertions. 869 StringRef name = "endInsert"; 870 TypeRange noTp; 871 createFuncCall(rewriter, op, name, noTp, adaptor.getOperands()); 872 } 873 rewriter.replaceOp(op, adaptor.getOperands()); 874 return success(); 875 } 876 }; 877 878 /// Sparse conversion rule for inserting in lexicographic index order. 
879 class SparseTensorLexInsertConverter : public OpConversionPattern<LexInsertOp> { 880 public: 881 using OpConversionPattern::OpConversionPattern; 882 LogicalResult 883 matchAndRewrite(LexInsertOp op, OpAdaptor adaptor, 884 ConversionPatternRewriter &rewriter) const override { 885 Type srcType = op.tensor().getType(); 886 Type eltType = srcType.cast<ShapedType>().getElementType(); 887 StringRef name; 888 if (eltType.isF64()) 889 name = "lexInsertF64"; 890 else if (eltType.isF32()) 891 name = "lexInsertF32"; 892 else if (eltType.isInteger(64)) 893 name = "lexInsertI64"; 894 else if (eltType.isInteger(32)) 895 name = "lexInsertI32"; 896 else if (eltType.isInteger(16)) 897 name = "lexInsertI16"; 898 else if (eltType.isInteger(8)) 899 name = "lexInsertI8"; 900 else 901 llvm_unreachable("Unknown element type"); 902 TypeRange noTp; 903 replaceOpWithFuncCall(rewriter, op, name, noTp, adaptor.getOperands(), 904 /*emitCInterface=*/true); 905 return success(); 906 } 907 }; 908 909 } // namespace 910 911 //===----------------------------------------------------------------------===// 912 // Public method for populating conversion rules. 913 //===----------------------------------------------------------------------===// 914 915 /// Populates the given patterns list with conversion rules required for 916 /// the sparsification of linear algebra operations. 917 void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter, 918 RewritePatternSet &patterns) { 919 patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter, 920 SparseCastConverter, SparseTensorNewConverter, 921 SparseTensorInitConverter, SparseTensorConvertConverter, 922 SparseTensorReleaseConverter, SparseTensorToPointersConverter, 923 SparseTensorToIndicesConverter, SparseTensorToValuesConverter, 924 SparseTensorLoadConverter, SparseTensorLexInsertConverter>( 925 typeConverter, patterns.getContext()); 926 } 927