//===- SparseTensorConversion.cpp - Sparse tensor primitives conversion ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Convert sparse tensor primitives to calls into a runtime support library.
// Note that this is a current implementation choice to keep the conversion
// simple. In principle, these primitives could also be converted to actual
// elaborate IR code that implements the primitives on the selected sparse
// tensor storage schemes.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

/// New tensor storage action. Keep these values consistent with
/// the sparse runtime support library.
enum Action : uint32_t {
  kEmpty = 0,
  kFromFile = 1,
  kFromCOO = 2,
  kEmptyCOO = 3,
  kToCOO = 4,
  kToIter = 5
};

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

/// Returns internal type encoding for primary storage. Keep these
/// values consistent with the sparse runtime support library.
static uint32_t getPrimaryTypeEncoding(Type tp) {
  if (tp.isF64())
    return 1;
  if (tp.isF32())
    return 2;
  if (tp.isInteger(64))
    return 3;
  if (tp.isInteger(32))
    return 4;
  if (tp.isInteger(16))
    return 5;
  if (tp.isInteger(8))
    return 6;
  return 0;
}

/// Returns internal type encoding for overhead storage. Keep these
/// values consistent with the sparse runtime support library.
static uint32_t getOverheadTypeEncoding(unsigned width) {
  switch (width) {
  default:
    return 1;
  case 32:
    return 2;
  case 16:
    return 3;
  case 8:
    return 4;
  }
}

/// Returns internal dimension level type encoding. Keep these
/// values consistent with the sparse runtime support library.
static uint32_t
getDimLevelTypeEncoding(SparseTensorEncodingAttr::DimLevelType dlt) {
  switch (dlt) {
  case SparseTensorEncodingAttr::DimLevelType::Dense:
    return 0;
  case SparseTensorEncodingAttr::DimLevelType::Compressed:
    return 1;
  case SparseTensorEncodingAttr::DimLevelType::Singleton:
    return 2;
  }
  llvm_unreachable("Unknown SparseTensorEncodingAttr::DimLevelType");
}
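
// As an illustrative example of the encodings above: for an f32 tensor
// stored with 64-bit pointers and 32-bit indices, the values passed to the
// runtime would be primary = 2 (getPrimaryTypeEncoding), pointer
// overhead = 1 and index overhead = 2 (getOverheadTypeEncoding, where the
// 64-bit width falls into the default case).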

/// Generates a constant zero of the given type.
inline static Value constantZero(ConversionPatternRewriter &rewriter,
                                 Location loc, Type t) {
  return rewriter.create<arith::ConstantOp>(loc, t, rewriter.getZeroAttr(t));
}

/// Generates a constant of `index` type.
inline static Value constantIndex(ConversionPatternRewriter &rewriter,
                                  Location loc, int64_t i) {
  return rewriter.create<arith::ConstantIndexOp>(loc, i);
}

/// Generates a constant of `i32` type.
inline static Value constantI32(ConversionPatternRewriter &rewriter,
                                Location loc, int32_t i) {
  return rewriter.create<arith::ConstantIntOp>(loc, i, 32);
}

/// Generates a constant of `i8` type.
inline static Value constantI8(ConversionPatternRewriter &rewriter,
                               Location loc, int8_t i) {
  return rewriter.create<arith::ConstantIntOp>(loc, i, 8);
}

/// Returns a function reference (first hit also inserts into module). Sets
/// the "llvm.emit_c_interface" attribute on the function declaration when
/// requested, so that LLVM lowering generates a wrapper function that takes
/// care of ABI complications with passing in and returning MemRefs to C
/// functions.
static FlatSymbolRefAttr getFunc(Operation *op, StringRef name,
                                 TypeRange resultType, ValueRange operands,
                                 bool emitCInterface = false) {
  MLIRContext *context = op->getContext();
  auto module = op->getParentOfType<ModuleOp>();
  auto result = SymbolRefAttr::get(context, name);
  auto func = module.lookupSymbol<FuncOp>(result.getAttr());
  if (!func) {
    OpBuilder moduleBuilder(module.getBodyRegion());
    func = moduleBuilder.create<FuncOp>(
        op->getLoc(), name,
        FunctionType::get(context, operands.getTypes(), resultType));
    func.setPrivate();
    if (emitCInterface)
      func->setAttr("llvm.emit_c_interface", UnitAttr::get(context));
  }
  return result;
}

/// Generates dimension size call.
static Value genDimSizeCall(ConversionPatternRewriter &rewriter, Operation *op,
                            SparseTensorEncodingAttr &enc, Value src,
                            int64_t idx) {
  // Permute the index according to an optional dimension ordering.
  if (AffineMap p = enc.getDimOrdering())
    idx = p.getPermutedPosition(idx);
  // Generate the call.
  Location loc = op->getLoc();
  StringRef name = "sparseDimSize";
  SmallVector<Value, 2> params;
  params.push_back(src);
  params.push_back(constantIndex(rewriter, loc, idx));
  Type iTp = rewriter.getIndexType();
  auto fn = getFunc(op, name, iTp, params);
  return rewriter.create<CallOp>(loc, iTp, fn, params).getResult(0);
}

/// Generates a call into the "swiss army knife" method of the sparse runtime
/// support library for materializing sparse tensors into the computation.
static Value genNewCall(ConversionPatternRewriter &rewriter, Operation *op,
                        ArrayRef<Value> params) {
  Location loc = op->getLoc();
  StringRef name = "newSparseTensor";
  Type pTp = LLVM::LLVMPointerType::get(rewriter.getI8Type());
  auto fn = getFunc(op, name, pTp, params, /*emitCInterface=*/true);
  auto call = rewriter.create<CallOp>(loc, pTp, fn, params);
  return call.getResult(0);
}

/// Populates given sizes array from type.
static void sizesFromType(ConversionPatternRewriter &rewriter,
                          SmallVector<Value, 4> &sizes, Location loc,
                          ShapedType stp) {
  auto shape = stp.getShape();
  for (unsigned i = 0, rank = stp.getRank(); i < rank; i++) {
    uint64_t s = shape[i] == ShapedType::kDynamicSize ? 0 : shape[i];
    sizes.push_back(constantIndex(rewriter, loc, s));
  }
}
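
// As an illustration: for a tensor type tensor<8x?xf32>, sizesFromType()
// above pushes the index constants 8 and 0. Static extents are emitted
// verbatim, while a dynamic extent is encoded as 0, a placeholder whose
// actual size is then supplied or checked by the runtime.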

/// Populates given sizes array from source.
static void sizesFromSrc(ConversionPatternRewriter &rewriter,
                         SmallVector<Value, 4> &sizes, Location loc,
                         Value src) {
  ShapedType stp = src.getType().cast<ShapedType>();
  for (unsigned i = 0, rank = stp.getRank(); i < rank; i++)
    sizes.push_back(linalg::createOrFoldDimOp(rewriter, loc, src, i));
}

/// Populates given sizes array from type (for static sizes) and from
/// a source that has already been converted into an opaque pointer
/// (for dynamic sizes).
static void sizesFromPtr(ConversionPatternRewriter &rewriter,
                         SmallVector<Value, 4> &sizes, Operation *op,
                         SparseTensorEncodingAttr &enc, ShapedType stp,
                         Value src) {
  auto shape = stp.getShape();
  for (unsigned i = 0, rank = stp.getRank(); i < rank; i++)
    if (shape[i] == ShapedType::kDynamicSize)
      sizes.push_back(genDimSizeCall(rewriter, op, enc, src, i));
    else
      sizes.push_back(constantIndex(rewriter, op->getLoc(), shape[i]));
}

/// Generates an uninitialized temporary buffer of the given size and
/// type, but returns it as type `memref<? x $tp>` (rather than as type
/// `memref<$sz x $tp>`).
static Value genAlloca(ConversionPatternRewriter &rewriter, Location loc,
                       unsigned sz, Type tp) {
  auto memTp = MemRefType::get({ShapedType::kDynamicSize}, tp);
  Value a = constantIndex(rewriter, loc, sz);
  return rewriter.create<memref::AllocaOp>(loc, memTp, ValueRange{a});
}

/// Generates an uninitialized temporary buffer with room for one value
/// of the given type, and returns the `memref<$tp>`.
static Value genAllocaScalar(ConversionPatternRewriter &rewriter, Location loc,
                             Type tp) {
  return rewriter.create<memref::AllocaOp>(loc, MemRefType::get({}, tp));
}

/// Generates a temporary buffer of the given type and given contents.
static Value genBuffer(ConversionPatternRewriter &rewriter, Location loc,
                       ArrayRef<Value> values) {
  unsigned sz = values.size();
  assert(sz >= 1);
  Value buffer = genAlloca(rewriter, loc, sz, values[0].getType());
  for (unsigned i = 0; i < sz; i++) {
    Value idx = constantIndex(rewriter, loc, i);
    rewriter.create<memref::StoreOp>(loc, values[i], buffer, idx);
  }
  return buffer;
}

/// Populates parameters required to call the "swiss army knife" method of
/// the sparse runtime support library for materializing sparse tensors into
/// the computation. The parameter layout is:
///   params[0]  dimension level type encodings (i8 buffer)
///   params[1]  dimension sizes (index buffer)
///   params[2]  dimension order permutation (index buffer)
///   params[3]  pointer overhead type encoding (i32)
///   params[4]  index overhead type encoding (i32)
///   params[5]  primary type encoding (i32)
///   params[6]  action (i32)
///   params[7]  opaque pointer (or null)
static void newParams(ConversionPatternRewriter &rewriter,
                      SmallVector<Value, 8> &params, Operation *op,
                      SparseTensorEncodingAttr &enc, uint32_t action,
                      ValueRange szs, Value ptr = Value()) {
  Location loc = op->getLoc();
  ArrayRef<SparseTensorEncodingAttr::DimLevelType> dlt = enc.getDimLevelType();
  unsigned sz = dlt.size();
  // Sparsity annotations.
  SmallVector<Value, 4> attrs;
  for (unsigned i = 0; i < sz; i++)
    attrs.push_back(constantI8(rewriter, loc, getDimLevelTypeEncoding(dlt[i])));
  params.push_back(genBuffer(rewriter, loc, attrs));
  // Dimension sizes array of the enveloping tensor. Useful for either
  // verification of external data, or for construction of internal data.
  SmallVector<Value, 4> sizes;
  for (Value s : szs)
    sizes.push_back(s);
  params.push_back(genBuffer(rewriter, loc, sizes));
  // Dimension order permutation array. This is the "identity" permutation by
  // default, or otherwise the "reverse" permutation of a given ordering, so
  // that indices can be mapped quickly to the right position.
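  // For example (an illustrative case): given the dimension ordering
  // (d0, d1, d2) -> (d1, d2, d0), the loop below stores the reverse
  // permutation [2, 0, 1], i.e. rev[p(i)] = i for each result position i.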
  SmallVector<Value, 4> rev(sz);
  if (AffineMap p = enc.getDimOrdering()) {
    for (unsigned i = 0; i < sz; i++)
      rev[p.getDimPosition(i)] = constantIndex(rewriter, loc, i);
  } else {
    for (unsigned i = 0; i < sz; i++)
      rev[i] = constantIndex(rewriter, loc, i);
  }
  params.push_back(genBuffer(rewriter, loc, rev));
  // Secondary and primary types encoding.
  ShapedType resType = op->getResult(0).getType().cast<ShapedType>();
  uint32_t secPtr = getOverheadTypeEncoding(enc.getPointerBitWidth());
  uint32_t secInd = getOverheadTypeEncoding(enc.getIndexBitWidth());
  uint32_t primary = getPrimaryTypeEncoding(resType.getElementType());
  assert(primary);
  params.push_back(constantI32(rewriter, loc, secPtr));
  params.push_back(constantI32(rewriter, loc, secInd));
  params.push_back(constantI32(rewriter, loc, primary));
  // User action and pointer.
  Type pTp = LLVM::LLVMPointerType::get(rewriter.getI8Type());
  if (!ptr)
    ptr = rewriter.create<LLVM::NullOp>(loc, pTp);
  params.push_back(constantI32(rewriter, loc, action));
  params.push_back(ptr);
}

/// Generates the comparison `v != 0` where `v` is of numeric type `t`.
/// For floating types, we use the "unordered" comparator (i.e., returns
/// true if `v` is NaN).
static Value genIsNonzero(ConversionPatternRewriter &rewriter, Location loc,
                          Value v) {
  Type t = v.getType();
  Value zero = constantZero(rewriter, loc, t);
  if (t.isa<FloatType>())
    return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE, v,
                                          zero);
  if (t.isIntOrIndex())
    return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, v,
                                          zero);
  llvm_unreachable("Unknown element type");
}
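
// For a small illustration, genIsNonzero() above emits
//   %0 = arith.cmpf une, %v, %zero : f64
// for an f64 operand, and
//   %0 = arith.cmpi ne, %v, %zero : i32
// for an i32 operand.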

/// Generates the code to read the value from tensor[ivs], and conditionally
/// stores the indices ivs to the memory in ind. The insertion point after
/// this routine is inside the if-then branch, right behind the assignment
/// to ind, which ensures that the addEltX call generated afterwards also
/// lands inside the if-then branch. The generated code looks like:
///    if (tensor[ivs] != 0) {
///      ind = ivs
static Value genIndexAndValueForDense(ConversionPatternRewriter &rewriter,
                                      Location loc, Value tensor, Value ind,
                                      ValueRange ivs) {
  Value val = rewriter.create<tensor::ExtractOp>(loc, tensor, ivs);
  Value cond = genIsNonzero(rewriter, loc, val);
  scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ false);
  rewriter.setInsertionPointToStart(&ifOp.thenRegion().front());
  unsigned i = 0;
  for (auto iv : ivs) {
    Value idx = constantIndex(rewriter, loc, i++);
    rewriter.create<memref::StoreOp>(loc, iv, ind, idx);
  }
  return val;
}

/// Generates a call that adds one element to a coordinate scheme.
/// In particular, this generates code like the following:
///   val = a[i1,..,ik];
///   if val != 0
///     t->add(val, [i1,..,ik], [p1,..,pk]);
static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
                          Type eltType, Value ptr, Value val, Value ind,
                          Value perm) {
  Location loc = op->getLoc();
  StringRef name;
  if (eltType.isF64())
    name = "addEltF64";
  else if (eltType.isF32())
    name = "addEltF32";
  else if (eltType.isInteger(64))
    name = "addEltI64";
  else if (eltType.isInteger(32))
    name = "addEltI32";
  else if (eltType.isInteger(16))
    name = "addEltI16";
  else if (eltType.isInteger(8))
    name = "addEltI8";
  else
    llvm_unreachable("Unknown element type");
  SmallVector<Value, 8> params;
  params.push_back(ptr);
  params.push_back(val);
  params.push_back(ind);
  params.push_back(perm);
  Type pTp = LLVM::LLVMPointerType::get(rewriter.getI8Type());
  auto fn = getFunc(op, name, pTp, params, /*emitCInterface=*/true);
  rewriter.create<CallOp>(loc, pTp, fn, params);
}

/// Generates a call to `iter->getNext()`. If there is a next element,
/// then it is copied into the out-parameters `ind` and `elemPtr`,
/// and the return value is true. If there isn't a next element, then
/// the memory for `iter` is freed and the return value is false.
static Value genGetNextCall(ConversionPatternRewriter &rewriter, Operation *op,
                            Value iter, Value ind, Value elemPtr) {
  Location loc = op->getLoc();
  Type elemTp = elemPtr.getType().cast<ShapedType>().getElementType();
  StringRef name;
  if (elemTp.isF64())
    name = "getNextF64";
  else if (elemTp.isF32())
    name = "getNextF32";
  else if (elemTp.isInteger(64))
    name = "getNextI64";
  else if (elemTp.isInteger(32))
    name = "getNextI32";
  else if (elemTp.isInteger(16))
    name = "getNextI16";
  else if (elemTp.isInteger(8))
    name = "getNextI8";
  else
    llvm_unreachable("Unknown element type");
  SmallVector<Value, 3> params;
  params.push_back(iter);
  params.push_back(ind);
  params.push_back(elemPtr);
  Type i1 = rewriter.getI1Type();
  auto fn = getFunc(op, name, i1, params, /*emitCInterface=*/true);
  auto call = rewriter.create<CallOp>(loc, i1, fn, params);
  return call.getResult(0);
}

/// If the tensor is a sparse constant, generates and returns the pair of
/// the constants for the indices and the values.
static Optional<std::pair<Value, Value>>
genSplitSparseConstant(ConversionPatternRewriter &rewriter, Location loc,
                       Value tensor) {
  if (auto constOp = tensor.getDefiningOp<arith::ConstantOp>()) {
    if (auto attr = constOp.getValue().dyn_cast<SparseElementsAttr>()) {
      DenseElementsAttr indicesAttr = attr.getIndices();
      Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr);
      DenseElementsAttr valuesAttr = attr.getValues();
      Value values = rewriter.create<arith::ConstantOp>(loc, valuesAttr);
      return std::make_pair(indices, values);
    }
  }
  return {};
}
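
// As an illustrative example, genSplitSparseConstant() above splits a
// constant such as
//   sparse<[[0, 0], [1, 2]], [1.0, 5.0]> : tensor<8x7xf64>
// into an indices constant (e.g. of type tensor<2x2xi64>) holding
// [[0, 0], [1, 2]] and a values constant of type tensor<2xf64> holding
// [1.0, 5.0].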

/// Generates the code to copy the index at indices[ivs] to ind, and return
/// the value at value[ivs].
static Value genIndexAndValueForSparse(ConversionPatternRewriter &rewriter,
                                       Location loc, Value indices,
                                       Value values, Value ind, ValueRange ivs,
                                       unsigned rank) {
  for (unsigned i = 0; i < rank; i++) {
    Value idx = constantIndex(rewriter, loc, i);
    Value val = rewriter.create<tensor::ExtractOp>(loc, indices,
                                                   ValueRange{ivs[0], idx});
    val =
        rewriter.create<arith::IndexCastOp>(loc, val, rewriter.getIndexType());
    rewriter.create<memref::StoreOp>(loc, val, ind, idx);
  }
  return rewriter.create<tensor::ExtractOp>(loc, values, ivs[0]);
}

/// Generates code to allocate a tensor of the given type, and zero
/// initialize it. If the tensor type has any dynamic sizes, then the
/// `sizes` parameter should be as filled by sizesFromPtr(); that way
/// we can reuse the genDimSizeCall() results generated by sizesFromPtr().
static Value allocDenseTensor(ConversionPatternRewriter &rewriter,
                              Location loc, RankedTensorType tensorTp,
                              ValueRange sizes) {
  Type elemTp = tensorTp.getElementType();
  auto shape = tensorTp.getShape();
  auto memTp = MemRefType::get(shape, elemTp);
  SmallVector<Value> dynamicSizes;
  for (unsigned i = 0, rank = tensorTp.getRank(); i < rank; i++) {
    if (shape[i] == ShapedType::kDynamicSize)
      dynamicSizes.push_back(sizes[i]);
  }
  Value mem = rewriter.create<memref::AllocOp>(loc, memTp, dynamicSizes);
  Value zero = constantZero(rewriter, loc, elemTp);
  rewriter.create<linalg::FillOp>(loc, zero, mem).result();
  return mem;
}

/// Inserts the element returned by genGetNextCall(_, ind, elemPtr) into
/// the tensor created by allocDenseTensor(). The `rank` is the rank
/// of the `tensor` and the length of `ind`.
static void insertScalarIntoDenseTensor(ConversionPatternRewriter &rewriter,
                                        Location loc, Value elemPtr,
                                        Value tensor, unsigned rank,
                                        Value ind) {
  SmallVector<Value, 4> ivs;
  ivs.reserve(rank);
  for (unsigned i = 0; i < rank; i++) {
    Value idx = constantIndex(rewriter, loc, i);
    ivs.push_back(rewriter.create<memref::LoadOp>(loc, ind, idx));
  }
  Value elemV = rewriter.create<memref::LoadOp>(loc, elemPtr);
  rewriter.create<memref::StoreOp>(loc, elemV, tensor, ivs);
}

//===----------------------------------------------------------------------===//
// Conversion rules.
//===----------------------------------------------------------------------===//

/// Sparse conversion rule for returns.
class SparseReturnConverter : public OpConversionPattern<ReturnOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<ReturnOp>(op, adaptor.getOperands());
    return success();
  }
};
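
// As an illustrative sketch, the dimension-access pattern below rewrites
//   %d = tensor.dim %t, %c1 : tensor<?x?xf64, #SparseMatrix>
// into a call on the opaque tensor handle:
//   %d = call @sparseDimSize(%ptr, %c1) : (!llvm.ptr<i8>, index) -> index
// (modulo permutation of the dimension by an optional dimension ordering;
// #SparseMatrix stands for an arbitrary sparse tensor encoding).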

/// Sparse conversion rule for dimension accesses.
class SparseTensorToDimSizeConverter
    : public OpConversionPattern<tensor::DimOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only rewrite annotated DimOp with constant index.
    auto enc = getSparseTensorEncoding(op.source().getType());
    if (!enc)
      return failure();
    Optional<int64_t> index = op.getConstantIndex();
    if (!index.hasValue())
      return failure();
    // Generate the call.
    Value src = adaptor.getOperands()[0];
    int64_t idx = index.getValue();
    rewriter.replaceOp(op, genDimSizeCall(rewriter, op, enc, src, idx));
    return success();
  }
};

/// Sparse conversion rule for trivial tensor casts.
class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only rewrite identically annotated source/dest.
    auto encDst = getSparseTensorEncoding(op.getType());
    auto encSrc = getSparseTensorEncoding(op.source().getType());
    if (!encDst || encDst != encSrc)
      return failure();
    rewriter.replaceOp(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for the new operator.
class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NewOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    auto enc = getSparseTensorEncoding(resType);
    if (!enc)
      return failure();
    // Generate the call to construct tensor from ptr. The sizes are
    // inferred from the result type of the new operator.
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 8> params;
    sizesFromType(rewriter, sizes, op.getLoc(), resType.cast<ShapedType>());
    Value ptr = adaptor.getOperands()[0];
    newParams(rewriter, params, op, enc, kFromFile, sizes, ptr);
    rewriter.replaceOp(op, genNewCall(rewriter, op, params));
    return success();
  }
};

/// Sparse conversion rule for the init operator.
class SparseTensorInitConverter : public OpConversionPattern<InitOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(InitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    auto enc = getSparseTensorEncoding(resType);
    if (!enc)
      return failure();
    // Generate the call to construct empty tensor. The sizes are
    // explicitly defined by the arguments to the init operator.
    SmallVector<Value, 8> params;
    newParams(rewriter, params, op, enc, kEmpty, adaptor.getOperands());
    rewriter.replaceOp(op, genNewCall(rewriter, op, params));
    return success();
  }
};

/// Sparse conversion rule for the convert operator.
class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ConvertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    Type resType = op.getType();
    Type srcType = op.source().getType();
    auto encDst = getSparseTensorEncoding(resType);
    auto encSrc = getSparseTensorEncoding(srcType);
    Value src = adaptor.getOperands()[0];
    if (encDst && encSrc) {
      // This is a sparse => sparse conversion, which is handled as follows:
      //   t = src->toCOO();         ; src to COO in dst order
      //   dst = newSparseTensor(t)
      // Using the coordinate scheme as an intermediate does not always
      // yield the fastest conversion but avoids the need for a full
      // O(N^2) conversion matrix.
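      // For instance (a sketch only), a CSR => CSC conversion lowers to
      // two calls into the library that share most of their parameters:
      //   %coo = call @newSparseTensor(..., kToCOO, %src)
      //   %dst = call @newSparseTensor(..., kFromCOO, %coo)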
      if (encDst == encSrc) {
        rewriter.replaceOp(op, adaptor.getOperands()); // hidden nop cast
        return success();
      }
      SmallVector<Value, 4> sizes;
      SmallVector<Value, 8> params;
      sizesFromPtr(rewriter, sizes, op, encSrc, srcType.cast<ShapedType>(),
                   src);
      // Set up encoding with right mix of src and dst so that the two
      // method calls can share most parameters, while still providing
      // the correct sparsity information to either of them.
      auto enc = SparseTensorEncodingAttr::get(
          op->getContext(), encDst.getDimLevelType(), encDst.getDimOrdering(),
          encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
      newParams(rewriter, params, op, enc, kToCOO, sizes, src);
      Value coo = genNewCall(rewriter, op, params);
      params[3] = constantI32(
          rewriter, loc, getOverheadTypeEncoding(encDst.getPointerBitWidth()));
      params[4] = constantI32(
          rewriter, loc, getOverheadTypeEncoding(encDst.getIndexBitWidth()));
      params[6] = constantI32(rewriter, loc, kFromCOO);
      params[7] = coo;
      rewriter.replaceOp(op, genNewCall(rewriter, op, params));
      return success();
    }
    if (!encDst && encSrc) {
      // This is sparse => dense conversion, which is handled as follows:
      //   dst = new Tensor(0);
      //   iter = src->toCOO();
      //   iter->startIterator();
      //   while (elem = iter->getNext()) {
      //     dst[elem.indices] = elem.value;
      //   }
      RankedTensorType dstTensorTp = resType.cast<RankedTensorType>();
      RankedTensorType srcTensorTp = srcType.cast<RankedTensorType>();
      unsigned rank = dstTensorTp.getRank();
      Type elemTp = dstTensorTp.getElementType();
      // Fabricate a no-permutation encoding for newParams().
      // The pointer/index types must be those of `src`.
      // The dimLevelTypes aren't actually used by kToIter.
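      // The loop generated further below then takes roughly the following
      // shape (a sketch only, for an f64 element type):
      //   scf.while {
      //     %more = call @getNextF64(%iter, %ind, %elemPtr)
      //     scf.condition(%more)
      //   } do {
      //     ... load from %ind and %elemPtr, store into %dst ...
      //     scf.yield
      //   }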
      encDst = SparseTensorEncodingAttr::get(
          op->getContext(),
          SmallVector<SparseTensorEncodingAttr::DimLevelType>(
              rank, SparseTensorEncodingAttr::DimLevelType::Dense),
          AffineMap(), encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
      SmallVector<Value, 4> sizes;
      SmallVector<Value, 8> params;
      sizesFromPtr(rewriter, sizes, op, encSrc, srcTensorTp, src);
      newParams(rewriter, params, op, encDst, kToIter, sizes, src);
      Value iter = genNewCall(rewriter, op, params);
      Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
      Value elemPtr = genAllocaScalar(rewriter, loc, elemTp);
      Value dst = allocDenseTensor(rewriter, loc, dstTensorTp, sizes);
      SmallVector<Value> noArgs;
      SmallVector<Type> noTypes;
      auto whileOp = rewriter.create<scf::WhileOp>(loc, noTypes, noArgs);
      Block *before = rewriter.createBlock(&whileOp.before(), {}, noTypes);
      rewriter.setInsertionPointToEnd(before);
      Value cond = genGetNextCall(rewriter, op, iter, ind, elemPtr);
      rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments());
      Block *after = rewriter.createBlock(&whileOp.after(), {}, noTypes);
      rewriter.setInsertionPointToStart(after);
      insertScalarIntoDenseTensor(rewriter, loc, elemPtr, dst, rank, ind);
      rewriter.create<scf::YieldOp>(loc);
      rewriter.setInsertionPointAfter(whileOp);
      rewriter.replaceOpWithNewOp<memref::TensorLoadOp>(op, resType, dst);
      return success();
    }
    if (!encDst && !encSrc) {
      // dense => dense
      return failure();
    }
    // This is a dense => sparse conversion or a sparse constant in COO =>
    // sparse conversion, which is handled as follows:
    //   t = newSparseCOO()
    //   ...code to fill the COO tensor t...
    //   s = newSparseTensor(t)
    //
    // To fill the COO tensor from a dense tensor:
    //   for i1 in dim1
    //    ..
    //     for ik in dimk
    //       val = a[i1,..,ik]
    //       if val != 0
    //         t->add(val, [i1,..,ik], [p1,..,pk])
    //
    // To fill the COO tensor from a sparse constant in COO format:
    //   for i in range(NNZ)
    //     val = values[i]
    //     [i1,..,ik] = indices[i]
    //     t->add(val, [i1,..,ik], [p1,..,pk])
    //
    // Note that the dense tensor traversal code is actually implemented
    // using MLIR IR to avoid having to expose too much low-level
    // memref traversal details to the runtime support library.
    // Also note that the code below only generates the "new" ops and
    // the loop-nest per se; whereas the entire body of the innermost
    // loop is generated by genAddElt().
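    // In MLIR terms, the generated loop nest takes roughly the following
    // shape for a 2-dimensional f64 tensor (a sketch only):
    //   scf.for %i0 = %c0 to %d0 step %c1 {
    //     scf.for %i1 = %c0 to %d1 step %c1 {
    //       %v = tensor.extract %src[%i0, %i1]
    //       %nz = arith.cmpf une, %v, %zero
    //       scf.if %nz {
    //         memref.store %i0, %ind[%c0]
    //         memref.store %i1, %ind[%c1]
    //         call @addEltF64(%ptr, %v, %ind, %perm)
    //       }
    //     }
    //   }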
    ShapedType stp = resType.cast<ShapedType>();
    unsigned rank = stp.getRank();
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 8> params;
    sizesFromSrc(rewriter, sizes, loc, src);
    newParams(rewriter, params, op, encDst, kEmptyCOO, sizes);
    Value ptr = genNewCall(rewriter, op, params);
    Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
    Value perm = params[2];
    SmallVector<Value> lo;
    SmallVector<Value> hi;
    SmallVector<Value> st;
    Value zero = constantIndex(rewriter, loc, 0);
    Value one = constantIndex(rewriter, loc, 1);
    auto indicesValues = genSplitSparseConstant(rewriter, loc, src);
    bool isCOOConstant = indicesValues.hasValue();
    Value indices;
    Value values;
    if (isCOOConstant) {
      indices = indicesValues->first;
      values = indicesValues->second;
      lo.push_back(zero);
      hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, values, 0));
      st.push_back(one);
    } else {
      for (unsigned i = 0; i < rank; i++) {
        lo.push_back(zero);
        hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, src, i));
        st.push_back(one);
      }
    }
    Type eltType = stp.getElementType();
    scf::buildLoopNest(
        rewriter, op.getLoc(), lo, hi, st, {},
        [&](OpBuilder &builder, Location loc, ValueRange ivs,
            ValueRange args) -> scf::ValueVector {
          Value val;
          if (isCOOConstant)
            val = genIndexAndValueForSparse(rewriter, loc, indices, values,
                                            ind, ivs, rank);
          else
            val = genIndexAndValueForDense(rewriter, loc, src, ind, ivs);
          genAddEltCall(rewriter, op, eltType, ptr, val, ind, perm);
          return {};
        });
    // Final call to construct sparse tensor storage.
    params[6] = constantI32(rewriter, loc, kFromCOO);
    params[7] = ptr;
    rewriter.replaceOp(op, genNewCall(rewriter, op, params));
    return success();
  }
};

/// Sparse conversion rule for the release operator.
class SparseTensorReleaseConverter : public OpConversionPattern<ReleaseOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReleaseOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    StringRef name = "delSparseTensor";
    TypeRange none;
    auto fn = getFunc(op, name, none, adaptor.getOperands());
    rewriter.create<CallOp>(op.getLoc(), none, fn, adaptor.getOperands());
    rewriter.eraseOp(op);
    return success();
  }
};

/// Sparse conversion rule for pointer accesses.
class SparseTensorToPointersConverter
    : public OpConversionPattern<ToPointersOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToPointersOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type eltType = resType.cast<ShapedType>().getElementType();
    StringRef name;
    if (eltType.isIndex())
      name = "sparsePointers";
    else if (eltType.isInteger(64))
      name = "sparsePointers64";
    else if (eltType.isInteger(32))
      name = "sparsePointers32";
    else if (eltType.isInteger(16))
      name = "sparsePointers16";
    else if (eltType.isInteger(8))
      name = "sparsePointers8";
    else
      return failure();
    auto fn = getFunc(op, name, resType, adaptor.getOperands(),
                      /*emitCInterface=*/true);
    rewriter.replaceOpWithNewOp<CallOp>(op, resType, fn, adaptor.getOperands());
    return success();
  }
};
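
// As an illustrative sketch, an access such as
//   %p = sparse_tensor.pointers %t, %c1 : tensor<?x?xf64, #CSR> to memref<?xindex>
// is rewritten by the patterns above and below into
//   %p = call @sparsePointers(%ptr, %c1) : (!llvm.ptr<i8>, index) -> memref<?xindex>
// with the suffixed variants (e.g. @sparsePointers64) selected by the
// element type of the result memref (#CSR stands for an arbitrary
// sparse tensor encoding).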

/// Sparse conversion rule for index accesses.
class SparseTensorToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToIndicesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type eltType = resType.cast<ShapedType>().getElementType();
    StringRef name;
    if (eltType.isIndex())
      name = "sparseIndices";
    else if (eltType.isInteger(64))
      name = "sparseIndices64";
    else if (eltType.isInteger(32))
      name = "sparseIndices32";
    else if (eltType.isInteger(16))
      name = "sparseIndices16";
    else if (eltType.isInteger(8))
      name = "sparseIndices8";
    else
      return failure();
    auto fn = getFunc(op, name, resType, adaptor.getOperands(),
                      /*emitCInterface=*/true);
    rewriter.replaceOpWithNewOp<CallOp>(op, resType, fn, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for value accesses.
class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type eltType = resType.cast<ShapedType>().getElementType();
    StringRef name;
    if (eltType.isF64())
      name = "sparseValuesF64";
    else if (eltType.isF32())
      name = "sparseValuesF32";
    else if (eltType.isInteger(64))
      name = "sparseValuesI64";
    else if (eltType.isInteger(32))
      name = "sparseValuesI32";
    else if (eltType.isInteger(16))
      name = "sparseValuesI16";
    else if (eltType.isInteger(8))
      name = "sparseValuesI8";
    else
      return failure();
    auto fn = getFunc(op, name, resType, adaptor.getOperands(),
                      /*emitCInterface=*/true);
    rewriter.replaceOpWithNewOp<CallOp>(op, resType, fn, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for tensor reconstruction.
class SparseTensorToTensorConverter : public OpConversionPattern<ToTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  // Simply fold the operator into the pointer to the sparse storage scheme.
  LogicalResult
  matchAndRewrite(ToTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Check that all arguments of the tensor reconstruction operator are
    // calls into the support library that query exactly the same opaque
    // pointer.
    Value ptr;
    for (Value op : adaptor.getOperands()) {
      if (auto call = op.getDefiningOp<CallOp>()) {
        Value arg = call.getOperand(0);
        if (!arg.getType().isa<LLVM::LLVMPointerType>())
          return failure();
        if (!ptr)
          ptr = arg;
        else if (arg != ptr)
          return failure();
      }
    }
    // If a single opaque pointer is found, perform the folding.
    if (!ptr)
      return failure();
    rewriter.replaceOp(op, ptr);
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//
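
// For context, a minimal driver for these patterns would look roughly as
// follows (a sketch only; the actual conversion pass and its type converter
// are defined elsewhere in this dialect's transforms):
//   RewritePatternSet patterns(ctx);
//   populateSparseTensorConversionPatterns(converter, patterns);
//   (void)applyPartialConversion(module, target, std::move(patterns));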

/// Populates the given patterns list with conversion rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
                                                  RewritePatternSet &patterns) {
  patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter,
               SparseCastConverter, SparseTensorNewConverter,
               SparseTensorInitConverter, SparseTensorConvertConverter,
               SparseTensorReleaseConverter, SparseTensorToPointersConverter,
               SparseTensorToIndicesConverter, SparseTensorToValuesConverter,
               SparseTensorToTensorConverter>(typeConverter,
                                              patterns.getContext());
}