//===- SparseTensorConversion.cpp - Sparse tensor primitives conversion ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Convert sparse tensor primitives to calls into a runtime support library.
// Note that this is a current implementation choice to keep the conversion
// simple. In principle, these primitives could also be converted to actual
// elaborate IR code that implements the primitives on the selected sparse
// tensor storage schemes.
//
//===----------------------------------------------------------------------===//

#include "CodegenUtils.h"

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/ExecutionEngine/SparseTensorUtils.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

/// Shorthand aliases for the `emitCInterface` argument to `getFunc()`,
/// `createFuncCall()`, and `replaceOpWithFuncCall()`.
enum class EmitCInterface : bool { Off = false, On = true };

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

/// Returns the equivalent of `void*` for opaque arguments to the
/// execution engine.
static Type getOpaquePointerType(OpBuilder &builder) {
  return LLVM::LLVMPointerType::get(builder.getI8Type());
}

/// Returns a function reference (first hit also inserts into module). Sets
/// the "llvm.emit_c_interface" attribute on the function declaration when
/// requested, so that LLVM lowering generates a wrapper function that takes
/// care of ABI complications with passing in and returning MemRefs to C
/// functions.
static FlatSymbolRefAttr getFunc(Operation *op, StringRef name,
                                 TypeRange resultType, ValueRange operands,
                                 EmitCInterface emitCInterface) {
  MLIRContext *context = op->getContext();
  auto module = op->getParentOfType<ModuleOp>();
  auto result = SymbolRefAttr::get(context, name);
  auto func = module.lookupSymbol<func::FuncOp>(result.getAttr());
  if (!func) {
    OpBuilder moduleBuilder(module.getBodyRegion());
    func = moduleBuilder.create<func::FuncOp>(
        op->getLoc(), name,
        FunctionType::get(context, operands.getTypes(), resultType));
    func.setPrivate();
    if (static_cast<bool>(emitCInterface))
      func->setAttr("llvm.emit_c_interface", UnitAttr::get(context));
  }
  return result;
}
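
// Illustration (a sketch, not part of the lowering itself): the first call
// to getFunc() with EmitCInterface::On for, say, "newSparseTensor"
// materializes a private declaration at module scope roughly like
//
//   func.func private @newSparseTensor(...) -> !llvm.ptr<i8>
//       attributes {llvm.emit_c_interface}
//
// and subsequent lookups return the same symbol without re-inserting it.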
/// Creates a `CallOp` to the function reference returned by `getFunc()`.
static func::CallOp createFuncCall(OpBuilder &builder, Operation *op,
                                   StringRef name, TypeRange resultType,
                                   ValueRange operands,
                                   EmitCInterface emitCInterface) {
  auto fn = getFunc(op, name, resultType, operands, emitCInterface);
  return builder.create<func::CallOp>(op->getLoc(), resultType, fn, operands);
}

/// Replaces the `op` with a `CallOp` to the function reference returned
/// by `getFunc()`.
static func::CallOp replaceOpWithFuncCall(RewriterBase &rewriter, Operation *op,
                                          StringRef name, TypeRange resultType,
                                          ValueRange operands,
                                          EmitCInterface emitCInterface) {
  auto fn = getFunc(op, name, resultType, operands, emitCInterface);
  return rewriter.replaceOpWithNewOp<func::CallOp>(op, resultType, fn,
                                                   operands);
}

/// Generates a call that returns the size of the given dimension.
static Value genDimSizeCall(OpBuilder &builder, Operation *op,
                            SparseTensorEncodingAttr &enc, Value src,
                            int64_t idx) {
  // Permute the index according to an optional dimension ordering.
  if (AffineMap p = enc.getDimOrdering())
    idx = p.getPermutedPosition(idx);
  // Generate the call.
  StringRef name = "sparseDimSize";
  SmallVector<Value, 2> params{src, constantIndex(builder, op->getLoc(), idx)};
  Type iTp = builder.getIndexType();
  return createFuncCall(builder, op, name, iTp, params, EmitCInterface::Off)
      .getResult(0);
}

/// Generates a call into the "swiss army knife" method of the sparse runtime
/// support library for materializing sparse tensors into the computation.
static Value genNewCall(OpBuilder &builder, Operation *op,
                        ArrayRef<Value> params) {
  StringRef name = "newSparseTensor";
  Type pTp = getOpaquePointerType(builder);
  return createFuncCall(builder, op, name, pTp, params, EmitCInterface::On)
      .getResult(0);
}

/// Populates given sizes array from type.
static void sizesFromType(OpBuilder &builder, SmallVector<Value, 4> &sizes,
                          Location loc, ShapedType stp) {
  auto shape = stp.getShape();
  for (unsigned i = 0, rank = stp.getRank(); i < rank; i++) {
    uint64_t s = shape[i] == ShapedType::kDynamicSize ? 0 : shape[i];
    sizes.push_back(constantIndex(builder, loc, s));
  }
}

/// Populates given sizes array from source.
static void sizesFromSrc(OpBuilder &builder, SmallVector<Value, 4> &sizes,
                         Location loc, Value src) {
  unsigned rank = src.getType().cast<ShapedType>().getRank();
  for (unsigned i = 0; i < rank; i++)
    sizes.push_back(linalg::createOrFoldDimOp(builder, loc, src, i));
}

/// Populates given sizes array from the type (for static sizes) and from
/// a source that has already been converted into an opaque pointer (for
/// dynamic sizes).
static void sizesFromPtr(OpBuilder &builder, SmallVector<Value, 4> &sizes,
                         Operation *op, SparseTensorEncodingAttr &enc,
                         ShapedType stp, Value src) {
  Location loc = op->getLoc();
  auto shape = stp.getShape();
  for (unsigned i = 0, rank = stp.getRank(); i < rank; i++)
    if (shape[i] == ShapedType::kDynamicSize)
      sizes.push_back(genDimSizeCall(builder, op, enc, src, i));
    else
      sizes.push_back(constantIndex(builder, loc, shape[i]));
}
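
// Illustration (a sketch): for a source of type tensor<8x?xf64, #SparseMatrix>
// that was already converted to an opaque pointer %src, sizesFromPtr()
// populates the array roughly as
//
//   sizes[0] = arith.constant 8 : index       // static size from the type
//   sizes[1] = call @sparseDimSize(%src, %c1) // dynamic size from the runtime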
/// Generates an uninitialized temporary buffer of the given size and
/// type, but returns it as type `memref<? x $tp>` (rather than as type
/// `memref<$sz x $tp>`).
static Value genAlloca(OpBuilder &builder, Location loc, Value sz, Type tp) {
  auto memTp = MemRefType::get({ShapedType::kDynamicSize}, tp);
  return builder.create<memref::AllocaOp>(loc, memTp, ValueRange{sz});
}

/// Generates an uninitialized buffer of the given size and type,
/// but returns it as type `memref<? x $tp>` (rather than as type
/// `memref<$sz x $tp>`). Unlike temporary buffers on the stack,
/// this buffer must be explicitly deallocated by the client.
static Value genAlloc(RewriterBase &rewriter, Location loc, Value sz, Type tp) {
  auto memTp = MemRefType::get({ShapedType::kDynamicSize}, tp);
  return rewriter.create<memref::AllocOp>(loc, memTp, ValueRange{sz});
}

/// Generates an uninitialized temporary buffer of the given size and
/// type, but returns it as type `memref<? x $tp>` (rather than as type
/// `memref<$sz x $tp>`).
static Value genAlloca(OpBuilder &builder, Location loc, unsigned sz, Type tp) {
  return genAlloca(builder, loc, constantIndex(builder, loc, sz), tp);
}

/// Generates an uninitialized temporary buffer with room for one value
/// of the given type, and returns the `memref<$tp>`.
static Value genAllocaScalar(OpBuilder &builder, Location loc, Type tp) {
  return builder.create<memref::AllocaOp>(loc, MemRefType::get({}, tp));
}

/// Generates a temporary buffer of the given type and given contents.
static Value genBuffer(OpBuilder &builder, Location loc, ValueRange values) {
  unsigned sz = values.size();
  assert(sz >= 1);
  Value buffer = genAlloca(builder, loc, sz, values[0].getType());
  for (unsigned i = 0; i < sz; i++) {
    Value idx = constantIndex(builder, loc, i);
    builder.create<memref::StoreOp>(loc, values[i], buffer, idx);
  }
  return buffer;
}

/// Populates parameters required to call the "swiss army knife" method of the
/// sparse runtime support library for materializing sparse tensors into the
/// computation.
static void newParams(OpBuilder &builder, SmallVector<Value, 8> &params,
                      Operation *op, ShapedType stp,
                      SparseTensorEncodingAttr &enc, Action action,
                      ValueRange szs, Value ptr = Value()) {
  Location loc = op->getLoc();
  ArrayRef<SparseTensorEncodingAttr::DimLevelType> dlt = enc.getDimLevelType();
  unsigned sz = dlt.size();
  // Sparsity annotations.
  SmallVector<Value, 4> attrs;
  for (unsigned i = 0; i < sz; i++)
    attrs.push_back(constantDimLevelTypeEncoding(builder, loc, dlt[i]));
  params.push_back(genBuffer(builder, loc, attrs));
  // Dimension sizes array of the enveloping tensor. Useful for either
  // verification of external data, or for construction of internal data.
  params.push_back(genBuffer(builder, loc, szs));
  // Dimension order permutation array. This is the "identity" permutation by
  // default, or otherwise the "reverse" permutation of a given ordering, so
  // that indices can be mapped quickly to the right position.
  SmallVector<Value, 4> rev(sz);
  if (AffineMap p = enc.getDimOrdering()) {
    for (unsigned i = 0; i < sz; i++)
      rev[p.getDimPosition(i)] = constantIndex(builder, loc, i);
  } else {
    for (unsigned i = 0; i < sz; i++)
      rev[i] = constantIndex(builder, loc, i);
  }
  params.push_back(genBuffer(builder, loc, rev));
  // Secondary and primary types encoding.
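  // (Note: downstream patterns overwrite entries of `params` in place
  // between consecutive newSparseTensor() calls, so the positions form a
  // small de facto ABI: [0] dim level types, [1] sizes, [2] dim ordering,
  // [3] pointer type, [4] index type, [5] primary type, [6] action,
  // [7] payload pointer.)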
  Type elemTp = stp.getElementType();
  params.push_back(constantPointerTypeEncoding(builder, loc, enc));
  params.push_back(constantIndexTypeEncoding(builder, loc, enc));
  params.push_back(constantPrimaryTypeEncoding(builder, loc, elemTp));
  // User action.
  params.push_back(constantAction(builder, loc, action));
  // Payload pointer.
  if (!ptr)
    ptr = builder.create<LLVM::NullOp>(loc, getOpaquePointerType(builder));
  params.push_back(ptr);
}

/// Generates the code to read the value from tensor[ivs], and conditionally
/// stores the indices ivs to the memory in ind. The generated code looks
/// like the following, and the insertion point after this routine is inside
/// the if-then branch behind the assignment to ind. This ensures that the
/// addEltX call generated after this routine is emitted inside the if-then
/// branch as well.
///    if (tensor[ivs] != 0)
///      ind = ivs
static Value genIndexAndValueForDense(OpBuilder &builder, Location loc,
                                      Value tensor, Value ind, ValueRange ivs) {
  Value val = builder.create<tensor::ExtractOp>(loc, tensor, ivs);
  Value cond = genIsNonzero(builder, loc, val);
  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else*/ false);
  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
  unsigned i = 0;
  for (auto iv : ivs) {
    Value idx = constantIndex(builder, loc, i++);
    builder.create<memref::StoreOp>(loc, iv, ind, idx);
  }
  return val;
}

/// Generates a call to release/delete a `SparseTensorCOO`.
static void genDelCOOCall(OpBuilder &builder, Operation *op, Type elemTp,
                          Value coo) {
  SmallString<21> name{"delSparseTensorCOO", primaryTypeFunctionSuffix(elemTp)};
  TypeRange noTp;
  createFuncCall(builder, op, name, noTp, coo, EmitCInterface::Off);
}

/// Generates a call that adds one element to a coordinate scheme.
/// In particular, this generates code like the following:
///    val = a[i1,..,ik];
///    if val != 0
///      t->add(val, [i1,..,ik], [p1,..,pk]);
static void genAddEltCall(OpBuilder &builder, Operation *op, Type eltType,
                          Value ptr, Value val, Value ind, Value perm) {
  SmallString<9> name{"addElt", primaryTypeFunctionSuffix(eltType)};
  SmallVector<Value, 4> params{ptr, val, ind, perm};
  Type pTp = getOpaquePointerType(builder);
  createFuncCall(builder, op, name, pTp, params, EmitCInterface::On);
}

/// Generates a call to `iter->getNext()`. If there is a next element,
/// then it is copied into the out-parameters `ind` and `elemPtr`,
/// and the return value is true. If there isn't a next element, then
/// the memory for `iter` is freed and the return value is false.
static Value genGetNextCall(OpBuilder &builder, Operation *op, Value iter,
                            Value ind, Value elemPtr) {
  Type elemTp = elemPtr.getType().cast<ShapedType>().getElementType();
  SmallString<10> name{"getNext", primaryTypeFunctionSuffix(elemTp)};
  SmallVector<Value, 3> params{iter, ind, elemPtr};
  Type i1 = builder.getI1Type();
  return createFuncCall(builder, op, name, i1, params, EmitCInterface::On)
      .getResult(0);
}
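
// Illustration of the name mangling used by the three helpers above (a
// sketch; the exact suffixes come from primaryTypeFunctionSuffix()): for an
// f64 element type, the generated calls resolve to runtime entry points
// such as "addEltF64", "getNextF64", and "delSparseTensorCOOF64".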
/// If the tensor is a sparse constant, generates and returns the pair of
/// the constants for the indices and the values.
static Optional<std::pair<Value, Value>>
genSplitSparseConstant(OpBuilder &builder, Location loc, Value tensor) {
  if (auto constOp = tensor.getDefiningOp<arith::ConstantOp>()) {
    if (auto attr = constOp.getValue().dyn_cast<SparseElementsAttr>()) {
      DenseElementsAttr indicesAttr = attr.getIndices();
      Value indices = builder.create<arith::ConstantOp>(loc, indicesAttr);
      DenseElementsAttr valuesAttr = attr.getValues();
      Value values = builder.create<arith::ConstantOp>(loc, valuesAttr);
      return std::make_pair(indices, values);
    }
  }
  return {};
}

/// Generates the code to copy the index at indices[ivs] to ind, and returns
/// the value at values[ivs].
static Value genIndexAndValueForSparse(OpBuilder &builder, Location loc,
                                       Value indices, Value values, Value ind,
                                       ValueRange ivs, unsigned rank) {
  for (unsigned i = 0; i < rank; i++) {
    Value idx = constantIndex(builder, loc, i);
    Value val = builder.create<tensor::ExtractOp>(loc, indices,
                                                  ValueRange{ivs[0], idx});
    val = builder.create<arith::IndexCastOp>(loc, builder.getIndexType(), val);
    builder.create<memref::StoreOp>(loc, val, ind, idx);
  }
  return builder.create<tensor::ExtractOp>(loc, values, ivs[0]);
}

/// Generates code to allocate a tensor of the given type, and zero
/// initialize it. If the tensor type has any dynamic sizes, then the
/// `sizes` parameter should be as filled by sizesFromPtr(); that way
/// we can reuse the genDimSizeCall() results generated by sizesFromPtr().
static Value allocDenseTensor(OpBuilder &builder, Location loc,
                              RankedTensorType tensorTp, ValueRange sizes) {
  Type elemTp = tensorTp.getElementType();
  auto shape = tensorTp.getShape();
  auto memTp = MemRefType::get(shape, elemTp);
  SmallVector<Value> dynamicSizes;
  for (unsigned i = 0, rank = tensorTp.getRank(); i < rank; i++) {
    if (shape[i] == ShapedType::kDynamicSize)
      dynamicSizes.push_back(sizes[i]);
  }
  Value mem = builder.create<memref::AllocOp>(loc, memTp, dynamicSizes);
  Value zero = constantZero(builder, loc, elemTp);
  builder.create<linalg::FillOp>(loc, ValueRange{zero}, ValueRange{mem});
  return mem;
}

/// Inserts the element returned by genGetNextCall(_, ind, elemPtr) into
/// the tensor created by allocDenseTensor(). The `rank` is the rank
/// of the `tensor` and the length of `ind`.
static void insertScalarIntoDenseTensor(OpBuilder &builder, Location loc,
                                        Value elemPtr, Value tensor,
                                        unsigned rank, Value ind) {
  SmallVector<Value, 4> ivs;
  ivs.reserve(rank);
  for (unsigned i = 0; i < rank; i++) {
    Value idx = constantIndex(builder, loc, i);
    ivs.push_back(builder.create<memref::LoadOp>(loc, ind, idx));
  }
  Value elemV = builder.create<memref::LoadOp>(loc, elemPtr);
  builder.create<memref::StoreOp>(loc, elemV, tensor, ivs);
}

/// Determines whether the runtime library supports direct conversion to the
/// given target `dimTypes`.
static bool canUseDirectConversion(
    ArrayRef<SparseTensorEncodingAttr::DimLevelType> dimTypes) {
  bool alreadyCompressed = false;
  for (uint64_t rank = dimTypes.size(), r = 0; r < rank; r++) {
    switch (dimTypes[r]) {
    case SparseTensorEncodingAttr::DimLevelType::Compressed:
      if (alreadyCompressed)
        return false; // Multiple compressed dimensions not yet supported.
      alreadyCompressed = true;
      break;
    case SparseTensorEncodingAttr::DimLevelType::Dense:
      if (alreadyCompressed)
        return false; // Dense after Compressed not yet supported.
      break;
    case SparseTensorEncodingAttr::DimLevelType::Singleton:
      // Although Singleton isn't generally supported yet, the direct
      // conversion method doesn't have any particular problems with
      // singleton after compressed.
      break;
    }
  }
  return true;
}

//===----------------------------------------------------------------------===//
// Conversion rules.
//===----------------------------------------------------------------------===//

/// Sparse conversion rule for returns.
class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for dimension accesses.
class SparseTensorToDimSizeConverter
    : public OpConversionPattern<tensor::DimOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only rewrite annotated DimOp with constant index.
    auto enc = getSparseTensorEncoding(op.source().getType());
    if (!enc)
      return failure();
    Optional<int64_t> index = op.getConstantIndex();
    if (!index.hasValue())
      return failure();
    // Generate the call.
    Value src = adaptor.getOperands()[0];
    int64_t idx = index.getValue();
    rewriter.replaceOp(op, genDimSizeCall(rewriter, op, enc, src, idx));
    return success();
  }
};

/// Sparse conversion rule for trivial tensor casts.
class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only rewrite identically annotated source/dest.
    auto encDst = getSparseTensorEncoding(op.getType());
    auto encSrc = getSparseTensorEncoding(op.source().getType());
    if (!encDst || encDst != encSrc)
      return failure();
    rewriter.replaceOp(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for the new operator.
class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NewOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    auto enc = getSparseTensorEncoding(resType);
    if (!enc)
      return failure();
    // Generate the call to construct tensor from ptr. The sizes are
    // inferred from the result type of the new operator.
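    // For example (a sketch of the rewrite): an op such as
    //   %t = sparse_tensor.new %ptr : !llvm.ptr<i8> to tensor<8x8xf64, #CSR>
    // becomes roughly
    //   %t = call @newSparseTensor(%lvlTypes, %sizes, %perm, ..., %ptr)
    //          : (...) -> !llvm.ptr<i8>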
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 8> params;
    ShapedType stp = resType.cast<ShapedType>();
    sizesFromType(rewriter, sizes, op.getLoc(), stp);
    Value ptr = adaptor.getOperands()[0];
    newParams(rewriter, params, op, stp, enc, Action::kFromFile, sizes, ptr);
    rewriter.replaceOp(op, genNewCall(rewriter, op, params));
    return success();
  }
};

/// Sparse conversion rule for the alloc operator.
class SparseTensorAllocConverter
    : public OpConversionPattern<bufferization::AllocTensorOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(bufferization::AllocTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    RankedTensorType resType = op.getType();
    auto enc = getSparseTensorEncoding(resType);
    if (!enc)
      return failure();
    // Gather all dimension sizes as SSA values.
    SmallVector<Value> sizes;
    unsigned int operandCtr = 0;
    for (int64_t i = 0; i < resType.getRank(); ++i) {
      if (resType.isDynamicDim(i)) {
        sizes.push_back(adaptor.getOperands()[operandCtr++]);
      } else {
        sizes.push_back(rewriter.create<arith::ConstantIndexOp>(
            op.getLoc(), op.getStaticSize(i)));
      }
    }
    // Generate the call to construct empty tensor. The sizes are
    // explicitly defined by the arguments to the alloc operator.
    SmallVector<Value, 8> params;
    ShapedType stp = resType.cast<ShapedType>();
    newParams(rewriter, params, op, stp, enc, Action::kEmpty, sizes);
    rewriter.replaceOp(op, genNewCall(rewriter, op, params));
    return success();
  }
};

/// Sparse conversion rule for the convert operator.
class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
  /// Options to control sparse code generation.
  SparseTensorConversionOptions options;

public:
  using OpConversionPattern::OpConversionPattern;
  SparseTensorConvertConverter(MLIRContext *context,
                               SparseTensorConversionOptions o)
      : OpConversionPattern<ConvertOp>(context), options(o) {}
  SparseTensorConvertConverter(TypeConverter &typeConv, MLIRContext *context,
                               SparseTensorConversionOptions o)
      : OpConversionPattern<ConvertOp>(typeConv, context), options(o) {}

  LogicalResult
  matchAndRewrite(ConvertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    Type resType = op.getType();
    Type srcType = op.source().getType();
    auto encDst = getSparseTensorEncoding(resType);
    auto encSrc = getSparseTensorEncoding(srcType);
    Value src = adaptor.getOperands()[0];
    if (encDst && encSrc) {
      // This is a sparse => sparse conversion, which is handled as follows:
      //   t = src->toCOO();         ; src to COO in dst order
      //   dst = newSparseTensor(t)
      // Using the coordinate scheme as an intermediate does not always
      // yield the fastest conversion but avoids the need for a full
      // O(N^2) conversion matrix.
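      // Concretely: with N supported formats, a direct scheme would need a
      // conversion routine per (src, dst) format pair, whereas routing
      // through COO needs only one toCOO and one fromCOO routine per format.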
      if (encDst == encSrc) {
        rewriter.replaceOp(op, adaptor.getOperands()); // hidden nop cast
        return success();
      }
      SmallVector<Value, 4> sizes;
      SmallVector<Value, 8> params;
      ShapedType stp = srcType.cast<ShapedType>();
      sizesFromPtr(rewriter, sizes, op, encSrc, stp, src);
      bool useDirectConversion;
      switch (options.sparseToSparseStrategy) {
      case SparseToSparseConversionStrategy::kViaCOO:
        useDirectConversion = false;
        break;
      case SparseToSparseConversionStrategy::kDirect:
        useDirectConversion = true;
        assert(canUseDirectConversion(encDst.getDimLevelType()) &&
               "Unsupported target for direct sparse-to-sparse conversion");
        break;
      case SparseToSparseConversionStrategy::kAuto:
        useDirectConversion = canUseDirectConversion(encDst.getDimLevelType());
        break;
      }
      if (useDirectConversion) {
        newParams(rewriter, params, op, stp, encDst, Action::kSparseToSparse,
                  sizes, src);
        rewriter.replaceOp(op, genNewCall(rewriter, op, params));
      } else { // Use via-COO conversion.
        // Set up the encoding with the right mix of src and dst so that
        // the two method calls can share most parameters, while still
        // providing the correct sparsity information to either of them.
        auto enc = SparseTensorEncodingAttr::get(
            op->getContext(), encDst.getDimLevelType(),
            encDst.getDimOrdering(), encSrc.getPointerBitWidth(),
            encSrc.getIndexBitWidth());
        newParams(rewriter, params, op, stp, enc, Action::kToCOO, sizes, src);
        Value coo = genNewCall(rewriter, op, params);
        params[3] = constantPointerTypeEncoding(rewriter, loc, encDst);
        params[4] = constantIndexTypeEncoding(rewriter, loc, encDst);
        params[6] = constantAction(rewriter, loc, Action::kFromCOO);
        params[7] = coo;
        Value dst = genNewCall(rewriter, op, params);
        genDelCOOCall(rewriter, op, stp.getElementType(), coo);
        rewriter.replaceOp(op, dst);
      }
      return success();
    }
    if (!encDst && encSrc) {
      // This is a sparse => dense conversion, which is handled as follows:
      //   dst = new Tensor(0);
      //   iter = src->toCOO();
      //   iter->startIterator();
      //   while (elem = iter->getNext()) {
      //     dst[elem.indices] = elem.value;
      //   }
      RankedTensorType dstTensorTp = resType.cast<RankedTensorType>();
      RankedTensorType srcTensorTp = srcType.cast<RankedTensorType>();
      unsigned rank = dstTensorTp.getRank();
      Type elemTp = dstTensorTp.getElementType();
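      // The IR generated below is roughly (a sketch):
      //   %iter = call @newSparseTensor(..., kToIterator, %src)
      //   scf.while : () -> () {
      //     %cond = call @getNext...(%iter, %ind, %elemPtr)
      //     scf.condition(%cond)
      //   } do {
      //     ... dst[ind] = *elemPtr ...
      //     scf.yield
      //   }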
      // Fabricate a no-permutation encoding for newParams().
      // The pointer/index types must be those of `src`.
      // The dimLevelTypes aren't actually used by Action::kToIterator.
      encDst = SparseTensorEncodingAttr::get(
          op->getContext(),
          SmallVector<SparseTensorEncodingAttr::DimLevelType>(
              rank, SparseTensorEncodingAttr::DimLevelType::Dense),
          AffineMap(), encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
      SmallVector<Value, 4> sizes;
      SmallVector<Value, 8> params;
      sizesFromPtr(rewriter, sizes, op, encSrc, srcTensorTp, src);
      newParams(rewriter, params, op, dstTensorTp, encDst, Action::kToIterator,
                sizes, src);
      Value iter = genNewCall(rewriter, op, params);
      Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
      Value elemPtr = genAllocaScalar(rewriter, loc, elemTp);
      Value dst = allocDenseTensor(rewriter, loc, dstTensorTp, sizes);
      SmallVector<Value> noArgs;
      SmallVector<Type> noTypes;
      auto whileOp = rewriter.create<scf::WhileOp>(loc, noTypes, noArgs);
      Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, noTypes);
      rewriter.setInsertionPointToEnd(before);
      Value cond = genGetNextCall(rewriter, op, iter, ind, elemPtr);
      rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments());
      Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, noTypes);
      rewriter.setInsertionPointToStart(after);
      insertScalarIntoDenseTensor(rewriter, loc, elemPtr, dst, rank, ind);
      rewriter.create<scf::YieldOp>(loc);
      rewriter.setInsertionPointAfter(whileOp);
      genDelCOOCall(rewriter, op, elemTp, iter);
      rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, resType, dst);
      return success();
    }
    if (!encDst && !encSrc) {
      // dense => dense
      return failure();
    }
    // This is a dense => sparse conversion or a sparse constant in COO =>
    // sparse conversion, which is handled as follows:
    //   t = newSparseCOO()
    //   ...code to fill the COO tensor t...
    //   s = newSparseTensor(t)
    //
    // To fill the COO tensor from a dense tensor:
    //   for i1 in dim1
    //     ..
    //       for ik in dimk
    //         val = a[i1,..,ik]
    //         if val != 0
    //           t->add(val, [i1,..,ik], [p1,..,pk])
    //
    // To fill the COO tensor from a sparse constant in COO format:
    //   for i in range(NNZ)
    //     val = values[i]
    //     [i1,..,ik] = indices[i]
    //     t->add(val, [i1,..,ik], [p1,..,pk])
    //
    // Note that the dense tensor traversal code is actually implemented
    // using MLIR IR to avoid having to expose too much low-level
    // memref traversal detail to the runtime support library.
    // Also note that the code below only generates the "new" ops and
    // the loop-nest per se; the entire body of the innermost loop is
    // generated by genIndexAndValueFor{Dense,Sparse}() and genAddEltCall().
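    // For example (a sketch): a sparse constant such as
    //   arith.constant sparse<[[0, 1], [2, 3]], [1.0, 5.0]> : tensor<4x4xf64>
    // takes the second path: the loop runs over the two stored entries and
    // reads the indices/values constants directly, rather than scanning all
    // sixteen entries of a dense tensor.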
    ShapedType stp = resType.cast<ShapedType>();
    unsigned rank = stp.getRank();
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 8> params;
    sizesFromSrc(rewriter, sizes, loc, src);
    newParams(rewriter, params, op, stp, encDst, Action::kEmptyCOO, sizes);
    Value coo = genNewCall(rewriter, op, params);
    Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
    Value perm = params[2];
    SmallVector<Value> lo;
    SmallVector<Value> hi;
    SmallVector<Value> st;
    Value zero = constantIndex(rewriter, loc, 0);
    Value one = constantIndex(rewriter, loc, 1);
    auto indicesValues = genSplitSparseConstant(rewriter, loc, src);
    bool isCOOConstant = indicesValues.hasValue();
    Value indices;
    Value values;
    if (isCOOConstant) {
      indices = indicesValues->first;
      values = indicesValues->second;
      lo.push_back(zero);
      hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, values, 0));
      st.push_back(one);
    } else {
      for (unsigned i = 0; i < rank; i++) {
        lo.push_back(zero);
        hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, src, i));
        st.push_back(one);
      }
    }
    Type eltType = stp.getElementType();
    scf::buildLoopNest(
        rewriter, op.getLoc(), lo, hi, st, {},
        [&](OpBuilder &builder, Location loc, ValueRange ivs,
            ValueRange args) -> scf::ValueVector {
          Value val;
          if (isCOOConstant)
            val = genIndexAndValueForSparse(rewriter, loc, indices, values,
                                            ind, ivs, rank);
          else
            val = genIndexAndValueForDense(rewriter, loc, src, ind, ivs);
          genAddEltCall(rewriter, op, eltType, coo, val, ind, perm);
          return {};
        });
    // Final call to construct sparse tensor storage.
    params[6] = constantAction(rewriter, loc, Action::kFromCOO);
    params[7] = coo;
    Value dst = genNewCall(rewriter, op, params);
    genDelCOOCall(rewriter, op, eltType, coo);
    rewriter.replaceOp(op, dst);
    return success();
  }
};

/// Sparse conversion rule for the release operator.
class SparseTensorReleaseConverter : public OpConversionPattern<ReleaseOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReleaseOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    StringRef name = "delSparseTensor";
    TypeRange noTp;
    createFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
                   EmitCInterface::Off);
    rewriter.eraseOp(op);
    return success();
  }
};

/// Sparse conversion rule for pointer accesses.
class SparseTensorToPointersConverter
    : public OpConversionPattern<ToPointersOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToPointersOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type ptrType = resType.cast<ShapedType>().getElementType();
    SmallString<16> name{"sparsePointers", overheadTypeFunctionSuffix(ptrType)};
    replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
                          EmitCInterface::On);
    return success();
  }
};
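
// Illustration of the accessor rules above and below (a sketch; the exact
// suffix comes from overheadTypeFunctionSuffix()): with 64-bit overhead
// types, sparse_tensor.pointers lowers roughly to
//
//   %p = call @sparsePointers64(%tensor, %dim)
//       : (!llvm.ptr<i8>, index) -> memref<?xi64>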
/// Sparse conversion rule for index accesses.
class SparseTensorToIndicesConverter
    : public OpConversionPattern<ToIndicesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToIndicesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type indType = resType.cast<ShapedType>().getElementType();
    SmallString<15> name{"sparseIndices", overheadTypeFunctionSuffix(indType)};
    replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
                          EmitCInterface::On);
    return success();
  }
};

/// Sparse conversion rule for value accesses.
class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type eltType = resType.cast<ShapedType>().getElementType();
    SmallString<15> name{"sparseValues", primaryTypeFunctionSuffix(eltType)};
    replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
                          EmitCInterface::On);
    return success();
  }
};

/// Sparse conversion rule for tensor rematerialization.
class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.hasInserts()) {
      // Finalize any pending insertions.
      StringRef name = "endInsert";
      TypeRange noTp;
      createFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
                     EmitCInterface::Off);
    }
    rewriter.replaceOp(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for inserting in lexicographic index order.
class SparseTensorLexInsertConverter
    : public OpConversionPattern<LexInsertOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(LexInsertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type elemTp = op.tensor().getType().cast<ShapedType>().getElementType();
    SmallString<12> name{"lexInsert", primaryTypeFunctionSuffix(elemTp)};
    TypeRange noTp;
    replaceOpWithFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
                          EmitCInterface::On);
    return success();
  }
};

/// Sparse conversion rule for the expand operator.
class SparseTensorExpandConverter : public OpConversionPattern<ExpandOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ExpandOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    ShapedType srcType = op.tensor().getType().cast<ShapedType>();
    Type eltType = srcType.getElementType();
    Type boolType = rewriter.getIntegerType(1);
    Type idxType = rewriter.getIndexType();
    // All initialization should be done on entry of the loop nest.
    rewriter.setInsertionPointAfter(op.tensor().getDefiningOp());
    // Determine the size for access expansion.
    auto enc = getSparseTensorEncoding(srcType);
    Value src = adaptor.getOperands()[0];
    Value sz = genDimSizeCall(rewriter, op, enc, src, srcType.getRank() - 1);
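    // E.g. (a sketch): for a ?x?xf64 matrix with a row-major CSR-style
    // encoding, this queries the column count, so the buffers allocated
    // below each cover one fully expanded innermost dimension.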
    // Allocate temporary buffers for values, filled-switch, and indices.
    // We do not use stack buffers for this, since the expanded size may
    // be rather large (as it envelops a single expanded dense dimension).
    Value values = genAlloc(rewriter, loc, sz, eltType);
    Value filled = genAlloc(rewriter, loc, sz, boolType);
    Value indices = genAlloc(rewriter, loc, sz, idxType);
    Value zero = constantZero(rewriter, loc, idxType);
    // Reset the values/filled-switch to all-zero/false. Note that this
    // introduces an O(N) operation into the computation, but this reset
    // operation is amortized over the innermost loops for the access
    // pattern expansion. As noted in the operation doc, we would like
    // to amortize this setup cost even between kernels.
    rewriter.create<linalg::FillOp>(
        loc, ValueRange{constantZero(rewriter, loc, eltType)},
        ValueRange{values});
    rewriter.create<linalg::FillOp>(
        loc, ValueRange{constantZero(rewriter, loc, boolType)},
        ValueRange{filled});
    // Replace expansion op with these buffers and initial index.
    assert(op.getNumResults() == 4);
    rewriter.replaceOp(op, {values, filled, indices, zero});
    return success();
  }
};

/// Sparse conversion rule for the compress operator.
class SparseTensorCompressConverter : public OpConversionPattern<CompressOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(CompressOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    // Note that this method call resets the values/filled-switch back to
    // all-zero/false by only iterating over the set elements, so the
    // complexity remains proportional to the sparsity of the expanded
    // access pattern.
    Type elemTp = op.tensor().getType().cast<ShapedType>().getElementType();
    SmallString<12> name{"expInsert", primaryTypeFunctionSuffix(elemTp)};
    TypeRange noTp;
    replaceOpWithFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
                          EmitCInterface::On);
    // Deallocate the buffers on exit of the loop nest.
    Operation *parent = op;
    for (; isa<scf::ForOp>(parent->getParentOp()) ||
           isa<scf::WhileOp>(parent->getParentOp()) ||
           isa<scf::ParallelOp>(parent->getParentOp()) ||
           isa<scf::IfOp>(parent->getParentOp());
         parent = parent->getParentOp())
      ;
    rewriter.setInsertionPointAfter(parent);
    rewriter.create<memref::DeallocOp>(loc, adaptor.getOperands()[2]);
    rewriter.create<memref::DeallocOp>(loc, adaptor.getOperands()[3]);
    rewriter.create<memref::DeallocOp>(loc, adaptor.getOperands()[4]);
    return success();
  }
};

/// Sparse conversion rule for the output operator.
class SparseTensorOutConverter : public OpConversionPattern<OutOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(OutOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    ShapedType srcType = op.tensor().getType().cast<ShapedType>();
    // Convert to default permuted COO.
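    // E.g. (a sketch): for a source annotated with a (j, i) dimension
    // ordering (CSC-style storage), the COO obtained here reflects storage
    // order, so the runtime must sort it into lexicographic (i, j) order
    // before writing the file; hence the `sort` flag below.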
    Value src = adaptor.getOperands()[0];
    auto encSrc = getSparseTensorEncoding(srcType);
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 8> params;
    sizesFromPtr(rewriter, sizes, op, encSrc, srcType, src);
    auto enc = SparseTensorEncodingAttr::get(
        op->getContext(), encSrc.getDimLevelType(), AffineMap(),
        encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
    newParams(rewriter, params, op, srcType, enc, Action::kToCOO, sizes, src);
    Value coo = genNewCall(rewriter, op, params);
    // Then output the tensor to external file with indices in the externally
    // visible lexicographic index order. A sort is required if the source was
    // not in that order yet (note that the sort can be dropped altogether if
    // the external format does not care about the order at all, but here we
    // assume it does).
    bool sort =
        encSrc.getDimOrdering() && !encSrc.getDimOrdering().isIdentity();
    params.clear();
    params.push_back(coo);
    params.push_back(adaptor.getOperands()[1]);
    params.push_back(constantI1(rewriter, loc, sort));
    Type eltType = srcType.getElementType();
    SmallString<18> name{"outSparseTensor", primaryTypeFunctionSuffix(eltType)};
    TypeRange noTp;
    createFuncCall(rewriter, op, name, noTp, params, EmitCInterface::Off);
    genDelCOOCall(rewriter, op, eltType, coo);
    rewriter.eraseOp(op);
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with conversion rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparseTensorConversionPatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    const SparseTensorConversionOptions &options) {
  patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter,
               SparseCastConverter, SparseTensorNewConverter,
               SparseTensorAllocConverter, SparseTensorReleaseConverter,
               SparseTensorToPointersConverter, SparseTensorToIndicesConverter,
               SparseTensorToValuesConverter, SparseTensorLoadConverter,
               SparseTensorLexInsertConverter, SparseTensorExpandConverter,
               SparseTensorCompressConverter, SparseTensorOutConverter>(
      typeConverter, patterns.getContext());
  patterns.add<SparseTensorConvertConverter>(typeConverter,
                                             patterns.getContext(), options);
}
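
// A conversion pass typically drives these patterns roughly as follows
// (a sketch; `converter` is assumed to be a TypeConverter that maps sparse
// tensor types to the opaque !llvm.ptr<i8> runtime handle):
//
//   RewritePatternSet patterns(ctx);
//   populateSparseTensorConversionPatterns(converter, patterns, options);
//   if (failed(applyPartialConversion(module, target, std::move(patterns))))
//     signalPassFailure();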