//===- SparseTensorConversion.cpp - Sparse tensor primitives conversion ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Convert sparse tensor primitives to calls into a runtime support library.
// Note that this is a current implementation choice to keep the conversion
// simple. In principle, these primitives could also be converted to actual
// elaborate IR code that implements the primitives on the selected sparse
// tensor storage schemes.
//
//===----------------------------------------------------------------------===//

#include "CodegenUtils.h"

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/ExecutionEngine/SparseTensorUtils.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

/// Shorthand aliases for the `emitCInterface` argument to `getFunc()`,
/// `createFuncCall()`, and `replaceOpWithFuncCall()`.
enum class EmitCInterface : bool { Off = false, On = true };

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

/// Returns the equivalent of `void*` for opaque arguments to the
/// execution engine.
static Type getOpaquePointerType(PatternRewriter &rewriter) {
  return LLVM::LLVMPointerType::get(rewriter.getI8Type());
}

/// Returns a function reference (first hit also inserts into module). Sets
/// the "llvm.emit_c_interface" attribute on the function declaration when
/// requested, so that LLVM lowering generates a wrapper function that takes
/// care of ABI complications with passing in and returning MemRefs to C
/// functions.
static FlatSymbolRefAttr getFunc(Operation *op, StringRef name,
                                 TypeRange resultType, ValueRange operands,
                                 EmitCInterface emitCInterface) {
  MLIRContext *context = op->getContext();
  auto module = op->getParentOfType<ModuleOp>();
  auto result = SymbolRefAttr::get(context, name);
  auto func = module.lookupSymbol<FuncOp>(result.getAttr());
  if (!func) {
    OpBuilder moduleBuilder(module.getBodyRegion());
    func = moduleBuilder.create<FuncOp>(
        op->getLoc(), name,
        FunctionType::get(context, operands.getTypes(), resultType));
    func.setPrivate();
    if (static_cast<bool>(emitCInterface))
      func->setAttr("llvm.emit_c_interface", UnitAttr::get(context));
  }
  return result;
}
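
// For illustration, the first getFunc() request for, say, "newSparseTensor"
// with EmitCInterface::On inserts a module-level declaration roughly like
// (sketch only; the exact operand types depend on the call site):
//
//   func private @newSparseTensor(...) -> !llvm.ptr<i8>
//       attributes {llvm.emit_c_interface}
//
// Subsequent requests with the same name simply reuse this declaration.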

/// Creates a `CallOp` to the function reference returned by `getFunc()`.
static func::CallOp createFuncCall(OpBuilder &builder, Operation *op,
                                   StringRef name, TypeRange resultType,
                                   ValueRange operands,
                                   EmitCInterface emitCInterface) {
  auto fn = getFunc(op, name, resultType, operands, emitCInterface);
  return builder.create<func::CallOp>(op->getLoc(), resultType, fn, operands);
}

/// Replaces the `op` with a `CallOp` to the function reference returned
/// by `getFunc()`.
static func::CallOp replaceOpWithFuncCall(PatternRewriter &rewriter,
                                          Operation *op, StringRef name,
                                          TypeRange resultType,
                                          ValueRange operands,
                                          EmitCInterface emitCInterface) {
  auto fn = getFunc(op, name, resultType, operands, emitCInterface);
  return rewriter.replaceOpWithNewOp<func::CallOp>(op, resultType, fn,
                                                   operands);
}

/// Generates a dimension size call.
static Value genDimSizeCall(ConversionPatternRewriter &rewriter, Operation *op,
                            SparseTensorEncodingAttr &enc, Value src,
                            int64_t idx) {
  // Permute the index according to an optional dimension ordering.
  if (AffineMap p = enc.getDimOrdering())
    idx = p.getPermutedPosition(idx);
  // Generate the call.
  StringRef name = "sparseDimSize";
  SmallVector<Value, 2> params{src, constantIndex(rewriter, op->getLoc(), idx)};
  Type iTp = rewriter.getIndexType();
  return createFuncCall(rewriter, op, name, iTp, params, EmitCInterface::Off)
      .getResult(0);
}

/// Generates a call into the "swiss army knife" method of the sparse runtime
/// support library for materializing sparse tensors into the computation.
static Value genNewCall(ConversionPatternRewriter &rewriter, Operation *op,
                        ArrayRef<Value> params) {
  StringRef name = "newSparseTensor";
  Type pTp = getOpaquePointerType(rewriter);
  return createFuncCall(rewriter, op, name, pTp, params, EmitCInterface::On)
      .getResult(0);
}

/// Populates the given sizes array from the type (static sizes become
/// constants; dynamic sizes become zero placeholders).
static void sizesFromType(ConversionPatternRewriter &rewriter,
                          SmallVector<Value, 4> &sizes, Location loc,
                          ShapedType stp) {
  auto shape = stp.getShape();
  for (unsigned i = 0, rank = stp.getRank(); i < rank; i++) {
    uint64_t s = shape[i] == ShapedType::kDynamicSize ? 0 : shape[i];
    sizes.push_back(constantIndex(rewriter, loc, s));
  }
}

/// Populates the given sizes array from the source tensor.
static void sizesFromSrc(ConversionPatternRewriter &rewriter,
                         SmallVector<Value, 4> &sizes, Location loc,
                         Value src) {
  unsigned rank = src.getType().cast<ShapedType>().getRank();
  for (unsigned i = 0; i < rank; i++)
    sizes.push_back(linalg::createOrFoldDimOp(rewriter, loc, src, i));
}

/// Populates the given sizes array from the type (for static sizes) and from
/// the source already converted into an opaque pointer (for dynamic sizes).
static void sizesFromPtr(ConversionPatternRewriter &rewriter,
                         SmallVector<Value, 4> &sizes, Operation *op,
                         SparseTensorEncodingAttr &enc, ShapedType stp,
                         Value src) {
  Location loc = op->getLoc();
  auto shape = stp.getShape();
  for (unsigned i = 0, rank = stp.getRank(); i < rank; i++)
    if (shape[i] == ShapedType::kDynamicSize)
      sizes.push_back(genDimSizeCall(rewriter, op, enc, src, i));
    else
      sizes.push_back(constantIndex(rewriter, loc, shape[i]));
}
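
// For example, with a dimension ordering affine_map<(i,j) -> (j,i)> on the
// source encoding, a genDimSizeCall() query for dimension 0 is permuted to
// storage dimension 1 before the call, yielding IR roughly like:
//
//   %c1 = arith.constant 1 : index
//   %s = call @sparseDimSize(%src, %c1) : (!llvm.ptr<i8>, index) -> index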

/// Generates an uninitialized temporary buffer of the given size and
/// type, but returns it as type `memref<? x $tp>` (rather than as type
/// `memref<$sz x $tp>`).
static Value genAlloca(ConversionPatternRewriter &rewriter, Location loc,
                       Value sz, Type tp) {
  auto memTp = MemRefType::get({ShapedType::kDynamicSize}, tp);
  return rewriter.create<memref::AllocaOp>(loc, memTp, ValueRange{sz});
}

/// Generates an uninitialized temporary buffer of the given size and
/// type, but returns it as type `memref<? x $tp>` (rather than as type
/// `memref<$sz x $tp>`).
static Value genAlloca(ConversionPatternRewriter &rewriter, Location loc,
                       unsigned sz, Type tp) {
  return genAlloca(rewriter, loc, constantIndex(rewriter, loc, sz), tp);
}

/// Generates an uninitialized temporary buffer with room for one value
/// of the given type, and returns the `memref<$tp>`.
static Value genAllocaScalar(ConversionPatternRewriter &rewriter, Location loc,
                             Type tp) {
  return rewriter.create<memref::AllocaOp>(loc, MemRefType::get({}, tp));
}

/// Generates a temporary buffer of the given type and given contents.
static Value genBuffer(ConversionPatternRewriter &rewriter, Location loc,
                       ValueRange values) {
  unsigned sz = values.size();
  assert(sz >= 1);
  Value buffer = genAlloca(rewriter, loc, sz, values[0].getType());
  for (unsigned i = 0; i < sz; i++) {
    Value idx = constantIndex(rewriter, loc, i);
    rewriter.create<memref::StoreOp>(loc, values[i], buffer, idx);
  }
  return buffer;
}

/// Populates the parameters required to call the "swiss army knife" method
/// of the sparse runtime support library for materializing sparse tensors
/// into the computation.
static void newParams(ConversionPatternRewriter &rewriter,
                      SmallVector<Value, 8> &params, Operation *op,
                      ShapedType stp, SparseTensorEncodingAttr &enc,
                      Action action, ValueRange szs, Value ptr = Value()) {
  Location loc = op->getLoc();
  ArrayRef<SparseTensorEncodingAttr::DimLevelType> dlt = enc.getDimLevelType();
  unsigned sz = dlt.size();
  // Sparsity annotations.
  SmallVector<Value, 4> attrs;
  for (unsigned i = 0; i < sz; i++)
    attrs.push_back(constantDimLevelTypeEncoding(rewriter, loc, dlt[i]));
  params.push_back(genBuffer(rewriter, loc, attrs));
  // Dimension sizes array of the enveloping tensor. Useful for either
  // verification of external data, or for construction of internal data.
  params.push_back(genBuffer(rewriter, loc, szs));
  // Dimension order permutation array. This is the "identity" permutation by
  // default, or otherwise the "reverse" permutation of a given ordering, so
  // that indices can be mapped quickly to the right position.
  SmallVector<Value, 4> rev(sz);
  if (AffineMap p = enc.getDimOrdering()) {
    for (unsigned i = 0; i < sz; i++)
      rev[p.getDimPosition(i)] = constantIndex(rewriter, loc, i);
  } else {
    for (unsigned i = 0; i < sz; i++)
      rev[i] = constantIndex(rewriter, loc, i);
  }
  params.push_back(genBuffer(rewriter, loc, rev));
  // Secondary and primary types encoding.
  Type elemTp = stp.getElementType();
  params.push_back(constantPointerTypeEncoding(rewriter, loc, enc));
  params.push_back(constantIndexTypeEncoding(rewriter, loc, enc));
  params.push_back(constantPrimaryTypeEncoding(rewriter, loc, elemTp));
  // User action.
  params.push_back(constantAction(rewriter, loc, action));
  // Payload pointer.
  if (!ptr)
    ptr = rewriter.create<LLVM::NullOp>(loc, getOpaquePointerType(rewriter));
  params.push_back(ptr);
}
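
// For reference, after newParams() the params vector has the following fixed
// layout, which the conversion patterns below rely on when overwriting
// individual entries before issuing a second newSparseTensor call:
//   params[0] : dim-level-type annotations buffer
//   params[1] : dimension sizes buffer
//   params[2] : dimension order permutation buffer
//   params[3] : pointer type encoding
//   params[4] : index type encoding
//   params[5] : primary (element) type encoding
//   params[6] : user action
//   params[7] : payload pointer (null if absent)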

/// Generates code to read the value from `tensor[ivs]` and, when the value
/// is nonzero, store the indices `ivs` into the memory of `ind`. The
/// generated code looks like the following, and the insertion point after
/// this routine is inside the if-then branch, right after the assignment
/// to `ind`. This ensures that the subsequently generated addEltX call is
/// also inside the if-then branch:
///    if (tensor[ivs] != 0)
///      ind = ivs
static Value genIndexAndValueForDense(ConversionPatternRewriter &rewriter,
                                      Location loc, Value tensor, Value ind,
                                      ValueRange ivs) {
  Value val = rewriter.create<tensor::ExtractOp>(loc, tensor, ivs);
  Value cond = genIsNonzero(rewriter, loc, val);
  scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ false);
  rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
  unsigned i = 0;
  for (auto iv : ivs) {
    Value idx = constantIndex(rewriter, loc, i++);
    rewriter.create<memref::StoreOp>(loc, iv, ind, idx);
  }
  return val;
}

/// Generates a call to release/delete a `SparseTensorCOO`.
static void genDelCOOCall(OpBuilder &builder, Operation *op, Type elemTp,
                          Value coo) {
  SmallString<21> name{"delSparseTensorCOO", primaryTypeFunctionSuffix(elemTp)};
  TypeRange noTp;
  createFuncCall(builder, op, name, noTp, coo, EmitCInterface::Off);
}

/// Generates a call that adds one element to a coordinate scheme.
/// In particular, this generates code like the following:
///    val = a[i1,..,ik];
///    if val != 0
///      t->add(val, [i1,..,ik], [p1,..,pk]);
static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
                          Type eltType, Value ptr, Value val, Value ind,
                          Value perm) {
  SmallString<9> name{"addElt", primaryTypeFunctionSuffix(eltType)};
  SmallVector<Value, 4> params{ptr, val, ind, perm};
  Type pTp = getOpaquePointerType(rewriter);
  createFuncCall(rewriter, op, name, pTp, params, EmitCInterface::On);
}

/// Generates a call to `iter->getNext()`. If there is a next element,
/// then it is copied into the out-parameters `ind` and `elemPtr`,
/// and the return value is true. If there isn't a next element, then
/// the memory for `iter` is freed and the return value is false.
static Value genGetNextCall(ConversionPatternRewriter &rewriter, Operation *op,
                            Value iter, Value ind, Value elemPtr) {
  Type elemTp = elemPtr.getType().cast<ShapedType>().getElementType();
  SmallString<10> name{"getNext", primaryTypeFunctionSuffix(elemTp)};
  SmallVector<Value, 3> params{iter, ind, elemPtr};
  Type i1 = rewriter.getI1Type();
  return createFuncCall(rewriter, op, name, i1, params, EmitCInterface::On)
      .getResult(0);
}
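
// Callers typically drive genGetNextCall() from an scf.while loop, so the
// consumed iterator yields one element per iteration, roughly like:
//
//   scf.while : () -> () {
//     %more = call @getNextF64(%iter, %ind, %elemPtr) : (...) -> i1
//     scf.condition(%more)
//   } do {
//     ... use %ind and %elemPtr ...
//     scf.yield
//   }
//
// (Sketch only; see SparseTensorConvertConverter below for the real code.)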

/// If the tensor is a sparse constant, generates and returns the pair of
/// the constants for the indices and the values.
static Optional<std::pair<Value, Value>>
genSplitSparseConstant(ConversionPatternRewriter &rewriter, Location loc,
                       Value tensor) {
  if (auto constOp = tensor.getDefiningOp<arith::ConstantOp>()) {
    if (auto attr = constOp.getValue().dyn_cast<SparseElementsAttr>()) {
      DenseElementsAttr indicesAttr = attr.getIndices();
      Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr);
      DenseElementsAttr valuesAttr = attr.getValues();
      Value values = rewriter.create<arith::ConstantOp>(loc, valuesAttr);
      return std::make_pair(indices, values);
    }
  }
  return {};
}

/// Generates the code to copy the index at `indices[ivs]` to `ind`, and
/// returns the value at `values[ivs]`.
static Value genIndexAndValueForSparse(ConversionPatternRewriter &rewriter,
                                       Location loc, Value indices,
                                       Value values, Value ind, ValueRange ivs,
                                       unsigned rank) {
  for (unsigned i = 0; i < rank; i++) {
    Value idx = constantIndex(rewriter, loc, i);
    Value val = rewriter.create<tensor::ExtractOp>(loc, indices,
                                                   ValueRange{ivs[0], idx});
    val =
        rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), val);
    rewriter.create<memref::StoreOp>(loc, val, ind, idx);
  }
  return rewriter.create<tensor::ExtractOp>(loc, values, ivs[0]);
}

/// Generates code to allocate a tensor of the given type, and zero
/// initialize it. If the tensor type has any dynamic sizes, then the
/// `sizes` parameter should be as filled by sizesFromPtr(); that way
/// we can reuse the genDimSizeCall() results generated by sizesFromPtr().
static Value allocDenseTensor(ConversionPatternRewriter &rewriter, Location loc,
                              RankedTensorType tensorTp, ValueRange sizes) {
  Type elemTp = tensorTp.getElementType();
  auto shape = tensorTp.getShape();
  auto memTp = MemRefType::get(shape, elemTp);
  SmallVector<Value> dynamicSizes;
  for (unsigned i = 0, rank = tensorTp.getRank(); i < rank; i++) {
    if (shape[i] == ShapedType::kDynamicSize)
      dynamicSizes.push_back(sizes[i]);
  }
  Value mem = rewriter.create<memref::AllocOp>(loc, memTp, dynamicSizes);
  Value zero = constantZero(rewriter, loc, elemTp);
  rewriter.create<linalg::FillOp>(loc, ValueRange{zero}, ValueRange{mem});
  return mem;
}

/// Inserts the element returned by genGetNextCall(_, ind, elemPtr) into
/// the tensor created by allocDenseTensor(). The `rank` is the rank
/// of the `tensor` and the length of `ind`.
static void insertScalarIntoDenseTensor(ConversionPatternRewriter &rewriter,
                                        Location loc, Value elemPtr,
                                        Value tensor, unsigned rank,
                                        Value ind) {
  SmallVector<Value, 4> ivs;
  ivs.reserve(rank);
  for (unsigned i = 0; i < rank; i++) {
    Value idx = constantIndex(rewriter, loc, i);
    ivs.push_back(rewriter.create<memref::LoadOp>(loc, ind, idx));
  }
  Value elemV = rewriter.create<memref::LoadOp>(loc, elemPtr);
  rewriter.create<memref::StoreOp>(loc, elemV, tensor, ivs);
}
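
// For illustration, for a tensor<?x4xf64> destination, allocDenseTensor()
// emits roughly (with %d0 taken from the sizes array):
//
//   %buf = memref.alloc(%d0) : memref<?x4xf64>
//   %zero = arith.constant 0.0 : f64
//   linalg.fill(%zero, %buf) : f64, memref<?x4xf64>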

//===----------------------------------------------------------------------===//
// Conversion rules.
//===----------------------------------------------------------------------===//

/// Sparse conversion rule for returns.
class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for dimension accesses.
class SparseTensorToDimSizeConverter
    : public OpConversionPattern<tensor::DimOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only rewrite annotated DimOp with constant index.
    auto enc = getSparseTensorEncoding(op.source().getType());
    if (!enc)
      return failure();
    Optional<int64_t> index = op.getConstantIndex();
    if (!index.hasValue())
      return failure();
    // Generate the call.
    Value src = adaptor.getOperands()[0];
    int64_t idx = index.getValue();
    rewriter.replaceOp(op, genDimSizeCall(rewriter, op, enc, src, idx));
    return success();
  }
};

/// Sparse conversion rule for trivial tensor casts.
class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only rewrite identically annotated source/dest.
    auto encDst = getSparseTensorEncoding(op.getType());
    auto encSrc = getSparseTensorEncoding(op.source().getType());
    if (!encDst || encDst != encSrc)
      return failure();
    rewriter.replaceOp(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for the new operator.
class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NewOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    auto enc = getSparseTensorEncoding(resType);
    if (!enc)
      return failure();
    // Generate the call to construct tensor from ptr. The sizes are
    // inferred from the result type of the new operator.
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 8> params;
    ShapedType stp = resType.cast<ShapedType>();
    sizesFromType(rewriter, sizes, op.getLoc(), stp);
    Value ptr = adaptor.getOperands()[0];
    newParams(rewriter, params, op, stp, enc, Action::kFromFile, sizes, ptr);
    rewriter.replaceOp(op, genNewCall(rewriter, op, params));
    return success();
  }
};
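
// For illustration, a sparse_tensor.new such as
//
//   %t = sparse_tensor.new %ptr : !llvm.ptr<i8> to tensor<?x?xf64, #CSR>
//
// becomes, roughly, the parameter setup from newParams() followed by
//
//   %t = call @newSparseTensor(%attrs, %sizes, %perm, ..., %ptr)
//       : (...) -> !llvm.ptr<i8>
//
// with the action encoding Action::kFromFile (sketch; the buffer details and
// type encodings are elided).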

/// Sparse conversion rule for the init operator.
class SparseTensorInitConverter : public OpConversionPattern<InitOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(InitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    auto enc = getSparseTensorEncoding(resType);
    if (!enc)
      return failure();
    // Generate the call to construct empty tensor. The sizes are
    // explicitly defined by the arguments to the init operator.
    SmallVector<Value, 8> params;
    ShapedType stp = resType.cast<ShapedType>();
    newParams(rewriter, params, op, stp, enc, Action::kEmpty,
              adaptor.getOperands());
    rewriter.replaceOp(op, genNewCall(rewriter, op, params));
    return success();
  }
};
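
// For illustration, a sparse_tensor.init with dynamic sizes, e.g.
//
//   %t = sparse_tensor.init [%d0, %d1] : tensor<?x?xf64, #CSR>
//
// lowers to a call @newSparseTensor(...) with Action::kEmpty and the two
// size operands stored into the sizes buffer (sketch only).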

/// Sparse conversion rule for the convert operator.
class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
  /// Options to control sparse code generation.
  SparseTensorConversionOptions options;

public:
  using OpConversionPattern::OpConversionPattern;
  SparseTensorConvertConverter(MLIRContext *context,
                               SparseTensorConversionOptions o)
      : OpConversionPattern<ConvertOp>(context), options(o) {}
  SparseTensorConvertConverter(TypeConverter &typeConv, MLIRContext *context,
                               SparseTensorConversionOptions o)
      : OpConversionPattern<ConvertOp>(typeConv, context), options(o) {}

  LogicalResult
  matchAndRewrite(ConvertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    Type resType = op.getType();
    Type srcType = op.source().getType();
    auto encDst = getSparseTensorEncoding(resType);
    auto encSrc = getSparseTensorEncoding(srcType);
    Value src = adaptor.getOperands()[0];
    if (encDst && encSrc) {
      // This is a sparse => sparse conversion, which is handled as follows:
      //   t = src->toCOO();         ; src to COO in dst order
      //   dst = newSparseTensor(t)
      // Using the coordinate scheme as an intermediate does not always
      // yield the fastest conversion but avoids the need for a full
      // O(N^2) conversion matrix.
      if (encDst == encSrc) {
        rewriter.replaceOp(op, adaptor.getOperands()); // hidden nop cast
        return success();
      }
      SmallVector<Value, 4> sizes;
      SmallVector<Value, 8> params;
      ShapedType stp = srcType.cast<ShapedType>();
      sizesFromPtr(rewriter, sizes, op, encSrc, stp, src);
      // Set up encoding with right mix of src and dst so that the two
      // method calls can share most parameters, while still providing
      // the correct sparsity information to either of them.
      auto enc = SparseTensorEncodingAttr::get(
          op->getContext(), encDst.getDimLevelType(), encDst.getDimOrdering(),
          encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
      newParams(rewriter, params, op, stp, enc, Action::kToCOO, sizes, src);
      Value coo = genNewCall(rewriter, op, params);
      params[3] = constantPointerTypeEncoding(rewriter, loc, encDst);
      params[4] = constantIndexTypeEncoding(rewriter, loc, encDst);
      params[6] = constantAction(rewriter, loc, Action::kFromCOO);
      params[7] = coo;
      Value dst = genNewCall(rewriter, op, params);
      genDelCOOCall(rewriter, op, stp.getElementType(), coo);
      rewriter.replaceOp(op, dst);
      return success();
    }
    if (!encDst && encSrc) {
      // This is sparse => dense conversion, which is handled as follows:
      //   dst = new Tensor(0);
      //   iter = src->toCOO();
      //   iter->startIterator();
      //   while (elem = iter->getNext()) {
      //     dst[elem.indices] = elem.value;
      //   }
      RankedTensorType dstTensorTp = resType.cast<RankedTensorType>();
      RankedTensorType srcTensorTp = srcType.cast<RankedTensorType>();
      unsigned rank = dstTensorTp.getRank();
      Type elemTp = dstTensorTp.getElementType();
      // Fabricate a no-permutation encoding for newParams().
      // The pointer/index types must be those of `src`.
      // The dimLevelTypes aren't actually used by Action::kToIterator.
      encDst = SparseTensorEncodingAttr::get(
          op->getContext(),
          SmallVector<SparseTensorEncodingAttr::DimLevelType>(
              rank, SparseTensorEncodingAttr::DimLevelType::Dense),
          AffineMap(), encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
      SmallVector<Value, 4> sizes;
      SmallVector<Value, 8> params;
      sizesFromPtr(rewriter, sizes, op, encSrc, srcTensorTp, src);
      newParams(rewriter, params, op, dstTensorTp, encDst, Action::kToIterator,
                sizes, src);
      Value iter = genNewCall(rewriter, op, params);
      Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
      Value elemPtr = genAllocaScalar(rewriter, loc, elemTp);
      Value dst = allocDenseTensor(rewriter, loc, dstTensorTp, sizes);
      SmallVector<Value> noArgs;
      SmallVector<Type> noTypes;
      auto whileOp = rewriter.create<scf::WhileOp>(loc, noTypes, noArgs);
      Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, noTypes);
      rewriter.setInsertionPointToEnd(before);
      Value cond = genGetNextCall(rewriter, op, iter, ind, elemPtr);
      rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments());
      Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, noTypes);
      rewriter.setInsertionPointToStart(after);
      insertScalarIntoDenseTensor(rewriter, loc, elemPtr, dst, rank, ind);
      rewriter.create<scf::YieldOp>(loc);
      rewriter.setInsertionPointAfter(whileOp);
      genDelCOOCall(rewriter, op, elemTp, iter);
      rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, resType, dst);
      return success();
    }
    if (!encDst && !encSrc) {
      // dense => dense
      return failure();
    }
    // This is a dense => sparse conversion or a sparse constant in COO =>
    // sparse conversion, which is handled as follows:
    //   t = newSparseCOO()
    //   ...code to fill the COO tensor t...
    //   s = newSparseTensor(t)
    //
    // To fill the COO tensor from a dense tensor:
    //   for i1 in dim1
    //    ..
    //     for ik in dimk
    //       val = a[i1,..,ik]
    //       if val != 0
    //         t->add(val, [i1,..,ik], [p1,..,pk])
    //
    // To fill the COO tensor from a sparse constant in COO format:
    //   for i in range(NNZ)
    //     val = values[i]
    //     [i1,..,ik] = indices[i]
    //     t->add(val, [i1,..,ik], [p1,..,pk])
    //
    // Note that the dense tensor traversal code is actually implemented
    // using MLIR IR to avoid having to expose too much low-level
    // memref traversal details to the runtime support library.
    // Also note that the code below only generates the "new" ops and
    // the loop-nest per se; whereas the entire body of the innermost
    // loop is generated by genAddEltCall().
    ShapedType stp = resType.cast<ShapedType>();
    unsigned rank = stp.getRank();
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 8> params;
    sizesFromSrc(rewriter, sizes, loc, src);
    newParams(rewriter, params, op, stp, encDst, Action::kEmptyCOO, sizes);
    Value coo = genNewCall(rewriter, op, params);
    Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
    Value perm = params[2];
    SmallVector<Value> lo;
    SmallVector<Value> hi;
    SmallVector<Value> st;
    Value zero = constantIndex(rewriter, loc, 0);
    Value one = constantIndex(rewriter, loc, 1);
    auto indicesValues = genSplitSparseConstant(rewriter, loc, src);
    bool isCOOConstant = indicesValues.hasValue();
    Value indices;
    Value values;
    if (isCOOConstant) {
      indices = indicesValues->first;
      values = indicesValues->second;
      lo.push_back(zero);
      hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, values, 0));
      st.push_back(one);
    } else {
      for (unsigned i = 0; i < rank; i++) {
        lo.push_back(zero);
        hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, src, i));
        st.push_back(one);
      }
    }
    Type eltType = stp.getElementType();
    scf::buildLoopNest(
        rewriter, op.getLoc(), lo, hi, st, {},
        [&](OpBuilder &builder, Location loc, ValueRange ivs,
            ValueRange args) -> scf::ValueVector {
          Value val;
          if (isCOOConstant)
            val = genIndexAndValueForSparse(rewriter, loc, indices, values,
                                            ind, ivs, rank);
          else
            val = genIndexAndValueForDense(rewriter, loc, src, ind, ivs);
          genAddEltCall(rewriter, op, eltType, coo, val, ind, perm);
          return {};
        });
    // Final call to construct sparse tensor storage.
    params[6] = constantAction(rewriter, loc, Action::kFromCOO);
    params[7] = coo;
    Value dst = genNewCall(rewriter, op, params);
    genDelCOOCall(rewriter, op, eltType, coo);
    rewriter.replaceOp(op, dst);
    return success();
  }
};
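
// For illustration, converting, e.g., a CSR tensor to a CSC tensor of the
// same shape results in a runtime call sequence roughly like:
//
//   %coo = call @newSparseTensor(..., kToCOO, %src)   ; src to COO, dst order
//   %dst = call @newSparseTensor(..., kFromCOO, %coo) ; COO to dst storage
//   call @delSparseTensorCOOF64(%coo)                 ; release intermediate
//
// (Sketch; the shared parameter buffers and type encodings are elided.)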

/// Sparse conversion rule for the release operator.
class SparseTensorReleaseConverter : public OpConversionPattern<ReleaseOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReleaseOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    StringRef name = "delSparseTensor";
    TypeRange noTp;
    createFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
                   EmitCInterface::Off);
    rewriter.eraseOp(op);
    return success();
  }
};

/// Sparse conversion rule for pointer accesses.
class SparseTensorToPointersConverter
    : public OpConversionPattern<ToPointersOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToPointersOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type ptrType = resType.cast<ShapedType>().getElementType();
    SmallString<16> name{"sparsePointers", overheadTypeFunctionSuffix(ptrType)};
    replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
                          EmitCInterface::On);
    return success();
  }
};
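
// For illustration, a sparse_tensor.pointers op yields a memref view of the
// underlying pointers array via a runtime call roughly like the following
// (the exact name suffix, assumed "64" here, is whatever
// overheadTypeFunctionSuffix() returns for the overhead type):
//
//   %p = call @sparsePointers64(%tensor, %dim)
//       : (!llvm.ptr<i8>, index) -> memref<?xi64>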

/// Sparse conversion rule for index accesses.
class SparseTensorToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToIndicesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type indType = resType.cast<ShapedType>().getElementType();
    SmallString<15> name{"sparseIndices", overheadTypeFunctionSuffix(indType)};
    replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
                          EmitCInterface::On);
    return success();
  }
};

/// Sparse conversion rule for value accesses.
class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type eltType = resType.cast<ShapedType>().getElementType();
    SmallString<15> name{"sparseValues", primaryTypeFunctionSuffix(eltType)};
    replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
                          EmitCInterface::On);
    return success();
  }
};

/// Sparse conversion rule for tensor rematerialization.
class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.hasInserts()) {
      // Finalize any pending insertions.
      StringRef name = "endInsert";
      TypeRange noTp;
      createFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
                     EmitCInterface::Off);
    }
    rewriter.replaceOp(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for inserting in lexicographic index order.
class SparseTensorLexInsertConverter : public OpConversionPattern<LexInsertOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(LexInsertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type elemTp = op.tensor().getType().cast<ShapedType>().getElementType();
    SmallString<12> name{"lexInsert", primaryTypeFunctionSuffix(elemTp)};
    TypeRange noTp;
    replaceOpWithFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
                          EmitCInterface::On);
    return success();
  }
};
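
// For illustration, a lexicographic insertion of an f64 value lowers to a
// call roughly like (with the suffix "F64" from primaryTypeFunctionSuffix()):
//
//   call @lexInsertF64(%tensor, %cursor, %value)
//       : (!llvm.ptr<i8>, memref<?xindex>, f64) -> ()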

/// Sparse conversion rule for the expand operator.
class SparseTensorExpandConverter : public OpConversionPattern<ExpandOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ExpandOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    ShapedType srcType = op.tensor().getType().cast<ShapedType>();
    Type eltType = srcType.getElementType();
    Type boolType = rewriter.getIntegerType(1);
    Type idxType = rewriter.getIndexType();
    // All initialization should be done on entry of the loop nest.
    rewriter.setInsertionPointAfter(op.tensor().getDefiningOp());
    // Determine the size for access expansion.
    auto enc = getSparseTensorEncoding(srcType);
    Value src = adaptor.getOperands()[0];
    Value sz = genDimSizeCall(rewriter, op, enc, src, srcType.getRank() - 1);
    // Allocate temporary stack buffers for values, filled-switch, and indices.
    Value values = genAlloca(rewriter, loc, sz, eltType);
    Value filled = genAlloca(rewriter, loc, sz, boolType);
    Value indices = genAlloca(rewriter, loc, sz, idxType);
    Value zero = constantZero(rewriter, loc, idxType);
    // Reset the values/filled-switch to all-zero/false. Note that this
    // introduces an O(N) operation into the computation, but this reset
    // operation is amortized over the innermost loops for the access
    // pattern expansion.
    rewriter.create<linalg::FillOp>(
        loc, ValueRange{constantZero(rewriter, loc, eltType)},
        ValueRange{values});
    rewriter.create<linalg::FillOp>(
        loc, ValueRange{constantZero(rewriter, loc, boolType)},
        ValueRange{filled});
    // Replace expansion op with these buffers and initial index.
    assert(op.getNumResults() == 4);
    rewriter.replaceOp(op, {values, filled, indices, zero});
    return success();
  }
};

/// Sparse conversion rule for the compress operator.
class SparseTensorCompressConverter : public OpConversionPattern<CompressOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(CompressOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Note that this method call resets the values/filled-switch back to
    // all-zero/false by only iterating over the set elements, so the
    // complexity remains proportional to the sparsity of the expanded
    // access pattern.
    Type elemTp = op.tensor().getType().cast<ShapedType>().getElementType();
    SmallString<12> name{"expInsert", primaryTypeFunctionSuffix(elemTp)};
    TypeRange noTp;
    replaceOpWithFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
                          EmitCInterface::On);
    return success();
  }
};
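
// For illustration, the expand/compress pair brackets the innermost loop of
// a kernel with a sparse output, roughly like (value names are hypothetical):
//
//   %values, %filled, %added, %count = sparse_tensor.expand %t ...
//   ... scatter into %values/%filled, append new indices to %added ...
//   sparse_tensor.compress %t, ..., %values, %filled, %added, %count
//
// which after this conversion becomes the stack allocations and linalg.fill
// resets above, followed by a call @expInsert{T}(...) per compress.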

/// Sparse conversion rule for the output operator.
class SparseTensorOutConverter : public OpConversionPattern<OutOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(OutOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    ShapedType srcType = op.tensor().getType().cast<ShapedType>();
    // Convert to default permuted COO.
    Value src = adaptor.getOperands()[0];
    auto encSrc = getSparseTensorEncoding(srcType);
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 8> params;
    sizesFromPtr(rewriter, sizes, op, encSrc, srcType, src);
    auto enc = SparseTensorEncodingAttr::get(
        op->getContext(), encSrc.getDimLevelType(), AffineMap(),
        encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
    newParams(rewriter, params, op, srcType, enc, Action::kToCOO, sizes, src);
    Value coo = genNewCall(rewriter, op, params);
    // Then output the tensor to an external file with indices in the
    // externally visible lexicographic index order. A sort is required if
    // the source was not in that order yet (note that the sort can be
    // dropped altogether if the external format does not care about the
    // order at all, but here we assume it does).
    bool sort =
        encSrc.getDimOrdering() && !encSrc.getDimOrdering().isIdentity();
    params.clear();
    params.push_back(coo);
    params.push_back(adaptor.getOperands()[1]);
    params.push_back(constantI1(rewriter, loc, sort));
    Type eltType = srcType.getElementType();
    SmallString<18> name{"outSparseTensor", primaryTypeFunctionSuffix(eltType)};
    TypeRange noTp;
    createFuncCall(rewriter, op, name, noTp, params, EmitCInterface::Off);
    genDelCOOCall(rewriter, op, eltType, coo);
    rewriter.eraseOp(op);
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with conversion rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparseTensorConversionPatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    const SparseTensorConversionOptions &options) {
  patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter,
               SparseCastConverter, SparseTensorNewConverter,
               SparseTensorInitConverter, SparseTensorReleaseConverter,
               SparseTensorToPointersConverter, SparseTensorToIndicesConverter,
               SparseTensorToValuesConverter, SparseTensorLoadConverter,
               SparseTensorLexInsertConverter, SparseTensorExpandConverter,
               SparseTensorCompressConverter, SparseTensorOutConverter>(
      typeConverter, patterns.getContext());
  patterns.add<SparseTensorConvertConverter>(typeConverter,
                                             patterns.getContext(), options);
}
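
// For reference, a minimal sketch of how a conversion pass might drive these
// patterns. The pass boilerplate and the type converter name are assumptions
// for illustration (the actual driver lives in the sparse tensor passes);
// the converter is expected to map annotated tensor types to !llvm.ptr<i8>:
//
//   void ExampleSparseConversionPass::runOnOperation() {
//     MLIRContext *ctx = &getContext();
//     RewritePatternSet patterns(ctx);
//     ExampleSparseTensorTypeConverter converter; // hypothetical name
//     ConversionTarget target(*ctx);
//     // ... mark sparse tensor ops illegal, func/call/return dynamically
//     // legal based on the converter, and runtime callees legal ...
//     populateSparseTensorConversionPatterns(converter, patterns,
//                                            SparseTensorConversionOptions());
//     if (failed(applyPartialConversion(getOperation(), target,
//                                       std::move(patterns))))
//       signalPassFailure();
//   }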