//===- SparseTensorConversion.cpp - Sparse tensor primitives conversion ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Convert sparse tensor primitives to calls into a runtime support library.
// Note that this is a current implementation choice to keep the conversion
// simple. In principle, these primitives could also be converted to actual
// elaborate IR code that implements the primitives on the selected sparse
// tensor storage schemes.
//
//===----------------------------------------------------------------------===//

#include "CodegenUtils.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/ExecutionEngine/SparseTensorUtils.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

/// Shorthand aliases for the `emitCInterface` argument to `getFunc()`,
/// `createFuncCall()`, and `replaceOpWithFuncCall()`.
enum class EmitCInterface : bool { Off = false, On = true };

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

/// Returns the equivalent of `void*` for opaque arguments to the
/// execution engine.
static Type getOpaquePointerType(PatternRewriter &rewriter) {
  return LLVM::LLVMPointerType::get(rewriter.getI8Type());
}

/// Returns a function reference (first hit also inserts into module). Sets
/// the "llvm.emit_c_interface" attribute on the function declaration when
/// requested, so that LLVM lowering generates a wrapper function that takes
/// care of ABI complications with passing in and returning MemRefs to C
/// functions.
static FlatSymbolRefAttr getFunc(Operation *op, StringRef name,
                                 TypeRange resultType, ValueRange operands,
                                 EmitCInterface emitCInterface) {
  MLIRContext *context = op->getContext();
  auto module = op->getParentOfType<ModuleOp>();
  auto result = SymbolRefAttr::get(context, name);
  auto func = module.lookupSymbol<FuncOp>(result.getAttr());
  if (!func) {
    OpBuilder moduleBuilder(module.getBodyRegion());
    func = moduleBuilder.create<FuncOp>(
        op->getLoc(), name,
        FunctionType::get(context, operands.getTypes(), resultType));
    func.setPrivate();
    if (static_cast<bool>(emitCInterface))
      func->setAttr("llvm.emit_c_interface", UnitAttr::get(context));
  }
  return result;
}

/// Creates a `CallOp` to the function reference returned by `getFunc()`.
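/// For example (roughly, with illustrative names), a call such as
///   createFuncCall(rewriter, op, "foo", {i64Tp}, {v}, EmitCInterface::On)
/// declares, on first use, a private `func @foo` at module scope with the
/// `llvm.emit_c_interface` attribute, and then emits
///   %0 = call @foo(%v) : (...) -> i64
/// at the location of `op`.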
static CallOp createFuncCall(OpBuilder &builder, Operation *op, StringRef name,
                             TypeRange resultType, ValueRange operands,
                             EmitCInterface emitCInterface) {
  auto fn = getFunc(op, name, resultType, operands, emitCInterface);
  return builder.create<CallOp>(op->getLoc(), resultType, fn, operands);
}

/// Replaces the `op` with a `CallOp` to the function reference returned
/// by `getFunc()`.
static CallOp replaceOpWithFuncCall(PatternRewriter &rewriter, Operation *op,
                                    StringRef name, TypeRange resultType,
                                    ValueRange operands,
                                    EmitCInterface emitCInterface) {
  auto fn = getFunc(op, name, resultType, operands, emitCInterface);
  return rewriter.replaceOpWithNewOp<CallOp>(op, resultType, fn, operands);
}

/// Generates a dimension size call.
static Value genDimSizeCall(ConversionPatternRewriter &rewriter, Operation *op,
                            SparseTensorEncodingAttr &enc, Value src,
                            int64_t idx) {
  // Permute the index according to an optional dimension ordering.
  if (AffineMap p = enc.getDimOrdering())
    idx = p.getPermutedPosition(idx);
  // Generate the call.
  StringRef name = "sparseDimSize";
  SmallVector<Value, 2> params{src, constantIndex(rewriter, op->getLoc(), idx)};
  Type iTp = rewriter.getIndexType();
  return createFuncCall(rewriter, op, name, iTp, params, EmitCInterface::Off)
      .getResult(0);
}

/// Generates a call into the "swiss army knife" method of the sparse runtime
/// support library for materializing sparse tensors into the computation.
static Value genNewCall(ConversionPatternRewriter &rewriter, Operation *op,
                        ArrayRef<Value> params) {
  StringRef name = "newSparseTensor";
  Type pTp = getOpaquePointerType(rewriter);
  return createFuncCall(rewriter, op, name, pTp, params, EmitCInterface::On)
      .getResult(0);
}

/// Populates the given sizes array from the type.
static void sizesFromType(ConversionPatternRewriter &rewriter,
                          SmallVector<Value, 4> &sizes, Location loc,
                          ShapedType stp) {
  auto shape = stp.getShape();
  for (unsigned i = 0, rank = stp.getRank(); i < rank; i++) {
    uint64_t s = shape[i] == ShapedType::kDynamicSize ? 0 : shape[i];
    sizes.push_back(constantIndex(rewriter, loc, s));
  }
}

/// Populates the given sizes array from the source.
static void sizesFromSrc(ConversionPatternRewriter &rewriter,
                         SmallVector<Value, 4> &sizes, Location loc,
                         Value src) {
  unsigned rank = src.getType().cast<ShapedType>().getRank();
  for (unsigned i = 0; i < rank; i++)
    sizes.push_back(linalg::createOrFoldDimOp(rewriter, loc, src, i));
}

/// Populates the given sizes array from the type (for static sizes) and
/// from the source, which has already been converted to an opaque pointer
/// (for dynamic sizes).
static void sizesFromPtr(ConversionPatternRewriter &rewriter,
                         SmallVector<Value, 4> &sizes, Operation *op,
                         SparseTensorEncodingAttr &enc, ShapedType stp,
                         Value src) {
  Location loc = op->getLoc();
  auto shape = stp.getShape();
  for (unsigned i = 0, rank = stp.getRank(); i < rank; i++)
    if (shape[i] == ShapedType::kDynamicSize)
      sizes.push_back(genDimSizeCall(rewriter, op, enc, src, i));
    else
      sizes.push_back(constantIndex(rewriter, loc, shape[i]));
}

/// Generates an uninitialized temporary buffer of the given size and
/// type, but returns it as type `memref<? x $tp>` (rather than as type
/// `memref<$sz x $tp>`).
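/// For example (roughly, with illustrative names),
///   genAlloca(rewriter, loc, sz, f64Tp)
/// emits
///   %0 = memref.alloca(%sz) : memref<?xf64>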
static Value genAlloca(ConversionPatternRewriter &rewriter, Location loc,
                       Value sz, Type tp) {
  auto memTp = MemRefType::get({ShapedType::kDynamicSize}, tp);
  return rewriter.create<memref::AllocaOp>(loc, memTp, ValueRange{sz});
}

/// Generates an uninitialized temporary buffer of the given size and
/// type, but returns it as type `memref<? x $tp>` (rather than as type
/// `memref<$sz x $tp>`).
static Value genAlloca(ConversionPatternRewriter &rewriter, Location loc,
                       unsigned sz, Type tp) {
  return genAlloca(rewriter, loc, constantIndex(rewriter, loc, sz), tp);
}

/// Generates an uninitialized temporary buffer with room for one value
/// of the given type, and returns the `memref<$tp>`.
static Value genAllocaScalar(ConversionPatternRewriter &rewriter, Location loc,
                             Type tp) {
  return rewriter.create<memref::AllocaOp>(loc, MemRefType::get({}, tp));
}

/// Generates a temporary buffer of the given type and given contents.
static Value genBuffer(ConversionPatternRewriter &rewriter, Location loc,
                       ArrayRef<Value> values) {
  unsigned sz = values.size();
  assert(sz >= 1);
  Value buffer = genAlloca(rewriter, loc, sz, values[0].getType());
  for (unsigned i = 0; i < sz; i++) {
    Value idx = constantIndex(rewriter, loc, i);
    rewriter.create<memref::StoreOp>(loc, values[i], buffer, idx);
  }
  return buffer;
}

/// Populates parameters required to call the "swiss army knife" method of the
/// sparse runtime support library for materializing sparse tensors into the
/// computation.
static void newParams(ConversionPatternRewriter &rewriter,
                      SmallVector<Value, 8> &params, Operation *op,
                      SparseTensorEncodingAttr &enc, Action action,
                      ValueRange szs, Value ptr = Value()) {
  Location loc = op->getLoc();
  ArrayRef<SparseTensorEncodingAttr::DimLevelType> dlt = enc.getDimLevelType();
  unsigned sz = dlt.size();
  // Sparsity annotations.
  SmallVector<Value, 4> attrs;
  for (unsigned i = 0; i < sz; i++)
    attrs.push_back(constantDimLevelTypeEncoding(rewriter, loc, dlt[i]));
  params.push_back(genBuffer(rewriter, loc, attrs));
  // Dimension sizes array of the enveloping tensor. Useful for either
  // verification of external data, or for construction of internal data.
  SmallVector<Value, 4> sizes;
  for (Value s : szs)
    sizes.push_back(s);
  params.push_back(genBuffer(rewriter, loc, sizes));
  // Dimension order permutation array. This is the "identity" permutation by
  // default, or otherwise the "reverse" permutation of a given ordering, so
  // that indices can be mapped quickly to the right position.
  SmallVector<Value, 4> rev(sz);
  if (AffineMap p = enc.getDimOrdering()) {
    for (unsigned i = 0; i < sz; i++)
      rev[p.getDimPosition(i)] = constantIndex(rewriter, loc, i);
  } else {
    for (unsigned i = 0; i < sz; i++)
      rev[i] = constantIndex(rewriter, loc, i);
  }
  params.push_back(genBuffer(rewriter, loc, rev));
  // Secondary and primary types encoding.
  Type elemTp = op->getResult(0).getType().cast<ShapedType>().getElementType();
  params.push_back(constantPointerTypeEncoding(rewriter, loc, enc));
  params.push_back(constantIndexTypeEncoding(rewriter, loc, enc));
  params.push_back(constantPrimaryTypeEncoding(rewriter, loc, elemTp));
  // User action.
  params.push_back(constantAction(rewriter, loc, action));
  // Payload pointer.
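  // When the caller provides no payload (e.g. for Action::kEmpty), a null
  // pointer is passed instead; roughly, this emits:
  //   %null = llvm.mlir.null : !llvm.ptr<i8>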
  if (!ptr)
    ptr = rewriter.create<LLVM::NullOp>(loc, getOpaquePointerType(rewriter));
  params.push_back(ptr);
}

/// Generates the code to read the value from tensor[ivs], and conditionally
/// stores the indices ivs to the memory in ind. The generated code looks like
/// the following, and the insertion point after this routine is inside the
/// if-then branch, right after the assignment to ind. This ensures that the
/// addEltX call generated afterwards is inside the if-then branch.
///   if (tensor[ivs] != 0) {
///     ind = ivs
static Value genIndexAndValueForDense(ConversionPatternRewriter &rewriter,
                                      Location loc, Value tensor, Value ind,
                                      ValueRange ivs) {
  Value val = rewriter.create<tensor::ExtractOp>(loc, tensor, ivs);
  Value cond = genIsNonzero(rewriter, loc, val);
  scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ false);
  rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
  unsigned i = 0;
  for (auto iv : ivs) {
    Value idx = constantIndex(rewriter, loc, i++);
    rewriter.create<memref::StoreOp>(loc, iv, ind, idx);
  }
  return val;
}

/// Generates a call that adds one element to a coordinate scheme.
/// In particular, this generates code like the following:
///   val = a[i1,..,ik];
///   if val != 0
///     t->add(val, [i1,..,ik], [p1,..,pk]);
static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
                          Type eltType, Value ptr, Value val, Value ind,
                          Value perm) {
  StringRef name;
  if (eltType.isF64())
    name = "addEltF64";
  else if (eltType.isF32())
    name = "addEltF32";
  else if (eltType.isInteger(64))
    name = "addEltI64";
  else if (eltType.isInteger(32))
    name = "addEltI32";
  else if (eltType.isInteger(16))
    name = "addEltI16";
  else if (eltType.isInteger(8))
    name = "addEltI8";
  else
    llvm_unreachable("Unknown element type");
  SmallVector<Value, 4> params{ptr, val, ind, perm};
  Type pTp = getOpaquePointerType(rewriter);
  createFuncCall(rewriter, op, name, pTp, params, EmitCInterface::On);
}

/// Generates a call to `iter->getNext()`. If there is a next element,
/// then it is copied into the out-parameters `ind` and `elemPtr`,
/// and the return value is true. If there isn't a next element, then
/// the memory for `iter` is freed and the return value is false.
static Value genGetNextCall(ConversionPatternRewriter &rewriter, Operation *op,
                            Value iter, Value ind, Value elemPtr) {
  Type elemTp = elemPtr.getType().cast<ShapedType>().getElementType();
  StringRef name;
  if (elemTp.isF64())
    name = "getNextF64";
  else if (elemTp.isF32())
    name = "getNextF32";
  else if (elemTp.isInteger(64))
    name = "getNextI64";
  else if (elemTp.isInteger(32))
    name = "getNextI32";
  else if (elemTp.isInteger(16))
    name = "getNextI16";
  else if (elemTp.isInteger(8))
    name = "getNextI8";
  else
    llvm_unreachable("Unknown element type");
  SmallVector<Value, 3> params{iter, ind, elemPtr};
  Type i1 = rewriter.getI1Type();
  return createFuncCall(rewriter, op, name, i1, params, EmitCInterface::On)
      .getResult(0);
}

/// If the tensor is a sparse constant, generates and returns the pair of
/// the constants for the indices and the values.
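/// For example (roughly), given the sparse constant
///   arith.constant sparse<[[0, 1], [2, 3]], [1.0, 5.0]> : tensor<4x4xf64>
/// this returns the dense constants [[0, 1], [2, 3]] (indices) and
/// [1.0, 5.0] (values).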
static Optional<std::pair<Value, Value>>
genSplitSparseConstant(ConversionPatternRewriter &rewriter, Location loc,
                       Value tensor) {
  if (auto constOp = tensor.getDefiningOp<arith::ConstantOp>()) {
    if (auto attr = constOp.getValue().dyn_cast<SparseElementsAttr>()) {
      DenseElementsAttr indicesAttr = attr.getIndices();
      Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr);
      DenseElementsAttr valuesAttr = attr.getValues();
      Value values = rewriter.create<arith::ConstantOp>(loc, valuesAttr);
      return std::make_pair(indices, values);
    }
  }
  return {};
}

/// Generates the code to copy the index at indices[ivs] to ind, and returns
/// the value at values[ivs].
static Value genIndexAndValueForSparse(ConversionPatternRewriter &rewriter,
                                       Location loc, Value indices,
                                       Value values, Value ind, ValueRange ivs,
                                       unsigned rank) {
  for (unsigned i = 0; i < rank; i++) {
    Value idx = constantIndex(rewriter, loc, i);
    Value val = rewriter.create<tensor::ExtractOp>(loc, indices,
                                                   ValueRange{ivs[0], idx});
    val =
        rewriter.create<arith::IndexCastOp>(loc, val, rewriter.getIndexType());
    rewriter.create<memref::StoreOp>(loc, val, ind, idx);
  }
  return rewriter.create<tensor::ExtractOp>(loc, values, ivs[0]);
}

/// Generates code to allocate a tensor of the given type, and zero
/// initialize it. If the tensor type has any dynamic sizes, then the
/// `sizes` parameter should be as filled by sizesFromPtr(); that way
/// we can reuse the genDimSizeCall() results generated by sizesFromPtr().
static Value allocDenseTensor(ConversionPatternRewriter &rewriter, Location loc,
                              RankedTensorType tensorTp, ValueRange sizes) {
  Type elemTp = tensorTp.getElementType();
  auto shape = tensorTp.getShape();
  auto memTp = MemRefType::get(shape, elemTp);
  SmallVector<Value> dynamicSizes;
  for (unsigned i = 0, rank = tensorTp.getRank(); i < rank; i++) {
    if (shape[i] == ShapedType::kDynamicSize)
      dynamicSizes.push_back(sizes[i]);
  }
  Value mem = rewriter.create<memref::AllocOp>(loc, memTp, dynamicSizes);
  Value zero = constantZero(rewriter, loc, elemTp);
  rewriter.create<linalg::FillOp>(loc, zero, mem);
  return mem;
}

/// Inserts the element returned by genGetNextCall(_, ind, elemPtr) into
/// the tensor created by allocDenseTensor(). The `rank` is the rank
/// of the `tensor` and the length of `ind`.
static void insertScalarIntoDenseTensor(ConversionPatternRewriter &rewriter,
                                        Location loc, Value elemPtr,
                                        Value tensor, unsigned rank,
                                        Value ind) {
  SmallVector<Value, 4> ivs;
  ivs.reserve(rank);
  for (unsigned i = 0; i < rank; i++) {
    Value idx = constantIndex(rewriter, loc, i);
    ivs.push_back(rewriter.create<memref::LoadOp>(loc, ind, idx));
  }
  Value elemV = rewriter.create<memref::LoadOp>(loc, elemPtr);
  rewriter.create<memref::StoreOp>(loc, elemV, tensor, ivs);
}

//===----------------------------------------------------------------------===//
// Conversion rules.
//===----------------------------------------------------------------------===//

/// Sparse conversion rule for returns.
class SparseReturnConverter : public OpConversionPattern<ReturnOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<ReturnOp>(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for dimension accesses.
class SparseTensorToDimSizeConverter
    : public OpConversionPattern<tensor::DimOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only rewrite annotated DimOp with constant index.
    auto enc = getSparseTensorEncoding(op.source().getType());
    if (!enc)
      return failure();
    Optional<int64_t> index = op.getConstantIndex();
    if (!index.hasValue())
      return failure();
    // Generate the call.
    Value src = adaptor.getOperands()[0];
    int64_t idx = index.getValue();
    rewriter.replaceOp(op, genDimSizeCall(rewriter, op, enc, src, idx));
    return success();
  }
};

/// Sparse conversion rule for trivial tensor casts.
class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only rewrite identically annotated source/dest.
    auto encDst = getSparseTensorEncoding(op.getType());
    auto encSrc = getSparseTensorEncoding(op.source().getType());
    if (!encDst || encDst != encSrc)
      return failure();
    rewriter.replaceOp(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for the new operator.
class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NewOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    auto enc = getSparseTensorEncoding(resType);
    if (!enc)
      return failure();
    // Generate the call to construct tensor from ptr. The sizes are
    // inferred from the result type of the new operator.
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 8> params;
    sizesFromType(rewriter, sizes, op.getLoc(), resType.cast<ShapedType>());
    Value ptr = adaptor.getOperands()[0];
    newParams(rewriter, params, op, enc, Action::kFromFile, sizes, ptr);
    rewriter.replaceOp(op, genNewCall(rewriter, op, params));
    return success();
  }
};

/// Sparse conversion rule for the init operator.
class SparseTensorInitConverter : public OpConversionPattern<InitOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(InitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    auto enc = getSparseTensorEncoding(resType);
    if (!enc)
      return failure();
    // Generate the call to construct empty tensor. The sizes are
    // explicitly defined by the arguments to the init operator.
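    // Roughly, with illustrative operand names, this becomes:
    //   %t = call @newSparseTensor(%attrs, %sizes, %perm, %ptrTp, %indTp,
    //                              %valTp, %c(kEmpty), %null)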
    SmallVector<Value, 8> params;
    newParams(rewriter, params, op, enc, Action::kEmpty, adaptor.getOperands());
    rewriter.replaceOp(op, genNewCall(rewriter, op, params));
    return success();
  }
};

/// Sparse conversion rule for the convert operator.
class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ConvertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    Type resType = op.getType();
    Type srcType = op.source().getType();
    auto encDst = getSparseTensorEncoding(resType);
    auto encSrc = getSparseTensorEncoding(srcType);
    Value src = adaptor.getOperands()[0];
    if (encDst && encSrc) {
      // This is a sparse => sparse conversion, which is handled as follows:
      //   t = src->toCOO();         ; src to COO in dst order
      //   dst = newSparseTensor(t)
      // Using the coordinate scheme as an intermediate does not always
      // yield the fastest conversion but avoids the need for a full
      // O(N^2) conversion matrix.
      if (encDst == encSrc) {
        rewriter.replaceOp(op, adaptor.getOperands()); // hidden nop cast
        return success();
      }
      SmallVector<Value, 4> sizes;
      SmallVector<Value, 8> params;
      sizesFromPtr(rewriter, sizes, op, encSrc, srcType.cast<ShapedType>(),
                   src);
      // Set up encoding with right mix of src and dst so that the two
      // method calls can share most parameters, while still providing
      // the correct sparsity information to either of them.
      auto enc = SparseTensorEncodingAttr::get(
          op->getContext(), encDst.getDimLevelType(), encDst.getDimOrdering(),
          encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
      newParams(rewriter, params, op, enc, Action::kToCOO, sizes, src);
      Value coo = genNewCall(rewriter, op, params);
      params[3] = constantPointerTypeEncoding(rewriter, loc, encDst);
      params[4] = constantIndexTypeEncoding(rewriter, loc, encDst);
      params[6] = constantAction(rewriter, loc, Action::kFromCOO);
      params[7] = coo;
      rewriter.replaceOp(op, genNewCall(rewriter, op, params));
      return success();
    }
    if (!encDst && encSrc) {
      // This is sparse => dense conversion, which is handled as follows:
      //   dst = new Tensor(0);
      //   iter = src->toCOO();
      //   iter->startIterator();
      //   while (elem = iter->getNext()) {
      //     dst[elem.indices] = elem.value;
      //   }
      RankedTensorType dstTensorTp = resType.cast<RankedTensorType>();
      RankedTensorType srcTensorTp = srcType.cast<RankedTensorType>();
      unsigned rank = dstTensorTp.getRank();
      Type elemTp = dstTensorTp.getElementType();
      // Fabricate a no-permutation encoding for newParams().
      // The pointer/index types must be those of `src`.
      // The dimLevelTypes aren't actually used by Action::kToIterator.
      encDst = SparseTensorEncodingAttr::get(
          op->getContext(),
          SmallVector<SparseTensorEncodingAttr::DimLevelType>(
              rank, SparseTensorEncodingAttr::DimLevelType::Dense),
          AffineMap(), encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
      SmallVector<Value, 4> sizes;
      SmallVector<Value, 8> params;
      sizesFromPtr(rewriter, sizes, op, encSrc, srcTensorTp, src);
      newParams(rewriter, params, op, encDst, Action::kToIterator, sizes, src);
      Value iter = genNewCall(rewriter, op, params);
      Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
      Value elemPtr = genAllocaScalar(rewriter, loc, elemTp);
      Value dst = allocDenseTensor(rewriter, loc, dstTensorTp, sizes);
      SmallVector<Value> noArgs;
      SmallVector<Type> noTypes;
      auto whileOp = rewriter.create<scf::WhileOp>(loc, noTypes, noArgs);
      Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, noTypes);
      rewriter.setInsertionPointToEnd(before);
      Value cond = genGetNextCall(rewriter, op, iter, ind, elemPtr);
      rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments());
      Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, noTypes);
      rewriter.setInsertionPointToStart(after);
      insertScalarIntoDenseTensor(rewriter, loc, elemPtr, dst, rank, ind);
      rewriter.create<scf::YieldOp>(loc);
      rewriter.setInsertionPointAfter(whileOp);
      rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, resType, dst);
      return success();
    }
    if (!encDst && !encSrc) {
      // dense => dense
      return failure();
    }
    // This is a dense => sparse conversion or a sparse constant in COO =>
    // sparse conversion, which is handled as follows:
    //   t = newSparseCOO()
    //   ...code to fill the COO tensor t...
    //   s = newSparseTensor(t)
    //
    // To fill the COO tensor from a dense tensor:
    //   for i1 in dim1
    //    ..
    //     for ik in dimk
    //       val = a[i1,..,ik]
    //       if val != 0
    //         t->add(val, [i1,..,ik], [p1,..,pk])
    //
    // To fill the COO tensor from a sparse constant in COO format:
    //   for i in range(NNZ)
    //     val = values[i]
    //     [i1,..,ik] = indices[i]
    //     t->add(val, [i1,..,ik], [p1,..,pk])
    //
    // Note that the dense tensor traversal code is actually implemented
    // using MLIR IR to avoid having to expose too much low-level
    // memref traversal details to the runtime support library.
    // Also note that the code below only generates the "new" ops and
    // the loop-nest per se; whereas the entire body of the innermost
    // loop is generated by genAddEltCall().
    ShapedType stp = resType.cast<ShapedType>();
    unsigned rank = stp.getRank();
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 8> params;
    sizesFromSrc(rewriter, sizes, loc, src);
    newParams(rewriter, params, op, encDst, Action::kEmptyCOO, sizes);
    Value ptr = genNewCall(rewriter, op, params);
    Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
    Value perm = params[2];
    SmallVector<Value> lo;
    SmallVector<Value> hi;
    SmallVector<Value> st;
    Value zero = constantIndex(rewriter, loc, 0);
    Value one = constantIndex(rewriter, loc, 1);
    auto indicesValues = genSplitSparseConstant(rewriter, loc, src);
    bool isCOOConstant = indicesValues.hasValue();
    Value indices;
    Value values;
    if (isCOOConstant) {
      indices = indicesValues->first;
      values = indicesValues->second;
      lo.push_back(zero);
      hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, values, 0));
      st.push_back(one);
    } else {
      for (unsigned i = 0; i < rank; i++) {
        lo.push_back(zero);
        hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, src, i));
        st.push_back(one);
      }
    }
    Type eltType = stp.getElementType();
    scf::buildLoopNest(
        rewriter, op.getLoc(), lo, hi, st, {},
        [&](OpBuilder &builder, Location loc, ValueRange ivs,
            ValueRange args) -> scf::ValueVector {
          Value val;
          if (isCOOConstant)
            val = genIndexAndValueForSparse(rewriter, loc, indices, values, ind,
                                            ivs, rank);
          else
            val = genIndexAndValueForDense(rewriter, loc, src, ind, ivs);
          genAddEltCall(rewriter, op, eltType, ptr, val, ind, perm);
          return {};
        });
    // Final call to construct sparse tensor storage.
    params[6] = constantAction(rewriter, loc, Action::kFromCOO);
    params[7] = ptr;
    rewriter.replaceOp(op, genNewCall(rewriter, op, params));
    return success();
  }
};

/// Sparse conversion rule for the release operator.
class SparseTensorReleaseConverter : public OpConversionPattern<ReleaseOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReleaseOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    StringRef name = "delSparseTensor";
    TypeRange noTp;
    createFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
                   EmitCInterface::Off);
    rewriter.eraseOp(op);
    return success();
  }
};

/// Sparse conversion rule for pointer accesses.
class SparseTensorToPointersConverter
    : public OpConversionPattern<ToPointersOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToPointersOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type eltType = resType.cast<ShapedType>().getElementType();
    StringRef name;
    if (eltType.isIndex())
      name = "sparsePointers";
    else if (eltType.isInteger(64))
      name = "sparsePointers64";
    else if (eltType.isInteger(32))
      name = "sparsePointers32";
    else if (eltType.isInteger(16))
      name = "sparsePointers16";
    else if (eltType.isInteger(8))
      name = "sparsePointers8";
    else
      return failure();
    replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
                          EmitCInterface::On);
    return success();
  }
};

/// Sparse conversion rule for index accesses.
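/// For example (roughly), with an `index` element type,
///   %ind = sparse_tensor.indices %t, %c1 : ... to memref<?xindex>
/// becomes a library call `call @sparseIndices(%ptr, %c1)`; for integer
/// element types the function name is suffixed by the bit-width instead
/// (e.g. @sparseIndices64).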
class SparseTensorToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToIndicesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type eltType = resType.cast<ShapedType>().getElementType();
    StringRef name;
    if (eltType.isIndex())
      name = "sparseIndices";
    else if (eltType.isInteger(64))
      name = "sparseIndices64";
    else if (eltType.isInteger(32))
      name = "sparseIndices32";
    else if (eltType.isInteger(16))
      name = "sparseIndices16";
    else if (eltType.isInteger(8))
      name = "sparseIndices8";
    else
      return failure();
    replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
                          EmitCInterface::On);
    return success();
  }
};

/// Sparse conversion rule for value accesses.
class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type eltType = resType.cast<ShapedType>().getElementType();
    StringRef name;
    if (eltType.isF64())
      name = "sparseValuesF64";
    else if (eltType.isF32())
      name = "sparseValuesF32";
    else if (eltType.isInteger(64))
      name = "sparseValuesI64";
    else if (eltType.isInteger(32))
      name = "sparseValuesI32";
    else if (eltType.isInteger(16))
      name = "sparseValuesI16";
    else if (eltType.isInteger(8))
      name = "sparseValuesI8";
    else
      return failure();
    replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
                          EmitCInterface::On);
    return success();
  }
};

/// Sparse conversion rule for tensor rematerialization.
class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.hasInserts()) {
      // Finalize any pending insertions.
      StringRef name = "endInsert";
      TypeRange noTp;
      createFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
                     EmitCInterface::Off);
    }
    rewriter.replaceOp(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for inserting in lexicographic index order.
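/// For example (roughly), `sparse_tensor.lex_insert %t, %ind, %val` with
/// f64 values becomes `call @lexInsertF64(%ptr, %ind, %val)`.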
class SparseTensorLexInsertConverter : public OpConversionPattern<LexInsertOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(LexInsertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = op.tensor().getType();
    Type eltType = srcType.cast<ShapedType>().getElementType();
    StringRef name;
    if (eltType.isF64())
      name = "lexInsertF64";
    else if (eltType.isF32())
      name = "lexInsertF32";
    else if (eltType.isInteger(64))
      name = "lexInsertI64";
    else if (eltType.isInteger(32))
      name = "lexInsertI32";
    else if (eltType.isInteger(16))
      name = "lexInsertI16";
    else if (eltType.isInteger(8))
      name = "lexInsertI8";
    else
      llvm_unreachable("Unknown element type");
    TypeRange noTp;
    replaceOpWithFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
                          EmitCInterface::On);
    return success();
  }
};

/// Sparse conversion rule for the expand operator.
class SparseTensorExpandConverter : public OpConversionPattern<ExpandOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ExpandOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    ShapedType srcType = op.tensor().getType().cast<ShapedType>();
    Type eltType = srcType.getElementType();
    Type boolType = rewriter.getIntegerType(1);
    Type idxType = rewriter.getIndexType();
    // All initialization should be done on entry of the loop nest.
    rewriter.setInsertionPointAfter(op.tensor().getDefiningOp());
    // Determine the size for access expansion.
    auto enc = getSparseTensorEncoding(srcType);
    Value src = adaptor.getOperands()[0];
    Value sz = genDimSizeCall(rewriter, op, enc, src, srcType.getRank() - 1);
    // Allocate temporary stack buffers for values, filled-switch, and indices.
    Value values = genAlloca(rewriter, loc, sz, eltType);
    Value filled = genAlloca(rewriter, loc, sz, boolType);
    Value indices = genAlloca(rewriter, loc, sz, idxType);
    Value zero = constantZero(rewriter, loc, idxType);
    // Reset the values/filled-switch to all-zero/false. Note that this
    // introduces an O(N) operation into the computation, but this reset
    // operation is amortized over the innermost loops for the access
    // pattern expansion.
    rewriter.create<linalg::FillOp>(loc, constantZero(rewriter, loc, eltType),
                                    values);
    rewriter.create<linalg::FillOp>(loc, constantZero(rewriter, loc, boolType),
                                    filled);
    // Replace expansion op with these buffers and initial index.
    assert(op.getNumResults() == 4);
    rewriter.replaceOp(op, {values, filled, indices, zero});
    return success();
  }
};

/// Sparse conversion rule for the compress operator.
class SparseTensorCompressConverter : public OpConversionPattern<CompressOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(CompressOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Note that this method call resets the values/filled-switch back to
    // all-zero/false by only iterating over the set elements, so the
    // complexity remains proportional to the sparsity of the expanded
    // access pattern.
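    // Roughly, for f64 values the compress op becomes
    //   call @expInsertF64(<converted compress operands>)
    // which merges the expanded values back into the sparse storage.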
    Type srcType = op.tensor().getType();
    Type eltType = srcType.cast<ShapedType>().getElementType();
    StringRef name;
    if (eltType.isF64())
      name = "expInsertF64";
    else if (eltType.isF32())
      name = "expInsertF32";
    else if (eltType.isInteger(64))
      name = "expInsertI64";
    else if (eltType.isInteger(32))
      name = "expInsertI32";
    else if (eltType.isInteger(16))
      name = "expInsertI16";
    else if (eltType.isInteger(8))
      name = "expInsertI8";
    else
      return failure();
    TypeRange noTp;
    replaceOpWithFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
                          EmitCInterface::On);
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with conversion rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
                                                  RewritePatternSet &patterns) {
  patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter,
               SparseCastConverter, SparseTensorNewConverter,
               SparseTensorInitConverter, SparseTensorConvertConverter,
               SparseTensorReleaseConverter, SparseTensorToPointersConverter,
               SparseTensorToIndicesConverter, SparseTensorToValuesConverter,
               SparseTensorLoadConverter, SparseTensorLexInsertConverter,
               SparseTensorExpandConverter, SparseTensorCompressConverter>(
      typeConverter, patterns.getContext());
}