//===- SparseTensorConversion.cpp - Sparse tensor primitives conversion ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Convert sparse tensor primitives to calls into a runtime support library.
// Note that this is a current implementation choice to keep the conversion
// simple. In principle, these primitives could also be converted to actual
// elaborate IR code that implements the primitives on the selected sparse
// tensor storage schemes.
//
//===----------------------------------------------------------------------===//

#include "CodegenUtils.h"

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/ExecutionEngine/SparseTensorUtils.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

/// Shorthand aliases for the `emitCInterface` argument to `getFunc()`,
/// `createFuncCall()`, and `replaceOpWithFuncCall()`.
enum class EmitCInterface : bool { Off = false, On = true };

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

/// Returns the equivalent of `void*` for opaque arguments to the
/// execution engine.
static Type getOpaquePointerType(PatternRewriter &rewriter) {
  return LLVM::LLVMPointerType::get(rewriter.getI8Type());
}

/// Returns a function reference (first hit also inserts into the module).
/// Sets the "llvm.emit_c_interface" attribute on the function declaration
/// when requested, so that LLVM lowering generates a wrapper function that
/// takes care of ABI complications with passing in and returning MemRefs
/// to C functions.
static FlatSymbolRefAttr getFunc(Operation *op, StringRef name,
                                 TypeRange resultType, ValueRange operands,
                                 EmitCInterface emitCInterface) {
  MLIRContext *context = op->getContext();
  auto module = op->getParentOfType<ModuleOp>();
  auto result = SymbolRefAttr::get(context, name);
  auto func = module.lookupSymbol<func::FuncOp>(result.getAttr());
  if (!func) {
    OpBuilder moduleBuilder(module.getBodyRegion());
    func = moduleBuilder.create<func::FuncOp>(
        op->getLoc(), name,
        FunctionType::get(context, operands.getTypes(), resultType));
    func.setPrivate();
    if (static_cast<bool>(emitCInterface))
      func->setAttr("llvm.emit_c_interface", UnitAttr::get(context));
  }
  return result;
}
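
// Illustrative note (a sketch, not part of the lowering contract): the first
// lookup of, e.g., "sparseDimSize" with result type `index` and an operand
// list of types (!llvm.ptr<i8>, index) declares at module scope
//   func.func private @sparseDimSize(!llvm.ptr<i8>, index) -> index
// and every later lookup simply returns the same symbol reference.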

/// Creates a `CallOp` to the function reference returned by `getFunc()`.
static func::CallOp createFuncCall(OpBuilder &builder, Operation *op,
                                   StringRef name, TypeRange resultType,
                                   ValueRange operands,
                                   EmitCInterface emitCInterface) {
  auto fn = getFunc(op, name, resultType, operands, emitCInterface);
  return builder.create<func::CallOp>(op->getLoc(), resultType, fn, operands);
}

/// Replaces the `op` with a `CallOp` to the function reference returned
/// by `getFunc()`.
static func::CallOp replaceOpWithFuncCall(PatternRewriter &rewriter,
                                          Operation *op, StringRef name,
                                          TypeRange resultType,
                                          ValueRange operands,
                                          EmitCInterface emitCInterface) {
  auto fn = getFunc(op, name, resultType, operands, emitCInterface);
  return rewriter.replaceOpWithNewOp<func::CallOp>(op, resultType, fn,
                                                   operands);
}

/// Generates a dimension size call.
static Value genDimSizeCall(ConversionPatternRewriter &rewriter, Operation *op,
                            SparseTensorEncodingAttr &enc, Value src,
                            int64_t idx) {
  // Permute the index according to an optional dimension ordering.
  if (AffineMap p = enc.getDimOrdering())
    idx = p.getPermutedPosition(idx);
  // Generate the call.
  StringRef name = "sparseDimSize";
  SmallVector<Value, 2> params{src, constantIndex(rewriter, op->getLoc(), idx)};
  Type iTp = rewriter.getIndexType();
  return createFuncCall(rewriter, op, name, iTp, params, EmitCInterface::Off)
      .getResult(0);
}

/// Generates a call into the "swiss army knife" method of the sparse runtime
/// support library for materializing sparse tensors into the computation.
static Value genNewCall(ConversionPatternRewriter &rewriter, Operation *op,
                        ArrayRef<Value> params) {
  StringRef name = "newSparseTensor";
  Type pTp = getOpaquePointerType(rewriter);
  return createFuncCall(rewriter, op, name, pTp, params, EmitCInterface::On)
      .getResult(0);
}

/// Populates given sizes array from type.
static void sizesFromType(ConversionPatternRewriter &rewriter,
                          SmallVector<Value, 4> &sizes, Location loc,
                          ShapedType stp) {
  auto shape = stp.getShape();
  for (unsigned i = 0, rank = stp.getRank(); i < rank; i++) {
    uint64_t s = shape[i] == ShapedType::kDynamicSize ? 0 : shape[i];
    sizes.push_back(constantIndex(rewriter, loc, s));
  }
}

/// Populates given sizes array from source.
static void sizesFromSrc(ConversionPatternRewriter &rewriter,
                         SmallVector<Value, 4> &sizes, Location loc,
                         Value src) {
  unsigned rank = src.getType().cast<ShapedType>().getRank();
  for (unsigned i = 0; i < rank; i++)
    sizes.push_back(linalg::createOrFoldDimOp(rewriter, loc, src, i));
}

/// Populates given sizes array from type (for static sizes) and from
/// an already converted opaque pointer source (for dynamic sizes).
static void sizesFromPtr(ConversionPatternRewriter &rewriter,
                         SmallVector<Value, 4> &sizes, Operation *op,
                         SparseTensorEncodingAttr &enc, ShapedType stp,
                         Value src) {
  Location loc = op->getLoc();
  auto shape = stp.getShape();
  for (unsigned i = 0, rank = stp.getRank(); i < rank; i++)
    if (shape[i] == ShapedType::kDynamicSize)
      sizes.push_back(genDimSizeCall(rewriter, op, enc, src, i));
    else
      sizes.push_back(constantIndex(rewriter, loc, shape[i]));
}
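
// For example (a sketch; the #CSR encoding name is hypothetical), for a
// source of type tensor<?x8xf64, #CSR>, sizesFromPtr() fills in
//   sizes[0] = call @sparseDimSize(%src, %c0)  // dynamic, queried at runtime
//   sizes[1] = arith.constant 8 : index        // static, taken from the type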

/// Generates an uninitialized temporary buffer of the given size and
/// type, but returns it as type `memref<? x $tp>` (rather than as type
/// `memref<$sz x $tp>`).
static Value genAlloca(ConversionPatternRewriter &rewriter, Location loc,
                       Value sz, Type tp) {
  auto memTp = MemRefType::get({ShapedType::kDynamicSize}, tp);
  return rewriter.create<memref::AllocaOp>(loc, memTp, ValueRange{sz});
}

/// Generates an uninitialized buffer of the given size and type,
/// but returns it as type `memref<? x $tp>` (rather than as type
/// `memref<$sz x $tp>`). Unlike temporary buffers on the stack,
/// this buffer must be explicitly deallocated by the client.
static Value genAlloc(ConversionPatternRewriter &rewriter, Location loc,
                      Value sz, Type tp) {
  auto memTp = MemRefType::get({ShapedType::kDynamicSize}, tp);
  return rewriter.create<memref::AllocOp>(loc, memTp, ValueRange{sz});
}

/// Generates an uninitialized temporary buffer of the given size and
/// type, but returns it as type `memref<? x $tp>` (rather than as type
/// `memref<$sz x $tp>`).
static Value genAlloca(ConversionPatternRewriter &rewriter, Location loc,
                       unsigned sz, Type tp) {
  return genAlloca(rewriter, loc, constantIndex(rewriter, loc, sz), tp);
}

/// Generates an uninitialized temporary buffer with room for one value
/// of the given type, and returns the `memref<$tp>`.
static Value genAllocaScalar(ConversionPatternRewriter &rewriter, Location loc,
                             Type tp) {
  return rewriter.create<memref::AllocaOp>(loc, MemRefType::get({}, tp));
}

/// Generates a temporary buffer of the given type and given contents.
static Value genBuffer(ConversionPatternRewriter &rewriter, Location loc,
                       ValueRange values) {
  unsigned sz = values.size();
  assert(sz >= 1);
  Value buffer = genAlloca(rewriter, loc, sz, values[0].getType());
  for (unsigned i = 0; i < sz; i++) {
    Value idx = constantIndex(rewriter, loc, i);
    rewriter.create<memref::StoreOp>(loc, values[i], buffer, idx);
  }
  return buffer;
}

/// Populates the parameters required to call the "swiss army knife" method
/// of the sparse runtime support library for materializing sparse tensors
/// into the computation.
static void newParams(ConversionPatternRewriter &rewriter,
                      SmallVector<Value, 8> &params, Operation *op,
                      ShapedType stp, SparseTensorEncodingAttr &enc,
                      Action action, ValueRange szs, Value ptr = Value()) {
  Location loc = op->getLoc();
  ArrayRef<SparseTensorEncodingAttr::DimLevelType> dlt = enc.getDimLevelType();
  unsigned sz = dlt.size();
  // Sparsity annotations.
  SmallVector<Value, 4> attrs;
  for (unsigned i = 0; i < sz; i++)
    attrs.push_back(constantDimLevelTypeEncoding(rewriter, loc, dlt[i]));
  params.push_back(genBuffer(rewriter, loc, attrs));
  // Dimension sizes array of the enveloping tensor. Useful for either
  // verification of external data, or for construction of internal data.
  params.push_back(genBuffer(rewriter, loc, szs));
  // Dimension order permutation array. This is the "identity" permutation by
  // default, or otherwise the "reverse" permutation of a given ordering, so
  // that indices can be mapped quickly to the right position.
  SmallVector<Value, 4> rev(sz);
  if (AffineMap p = enc.getDimOrdering()) {
    for (unsigned i = 0; i < sz; i++)
      rev[p.getDimPosition(i)] = constantIndex(rewriter, loc, i);
  } else {
    for (unsigned i = 0; i < sz; i++)
      rev[i] = constantIndex(rewriter, loc, i);
  }
  params.push_back(genBuffer(rewriter, loc, rev));
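
  // Worked example (for illustration only): given the dimension ordering
  // affine_map<(i, j) -> (j, i)>, getDimPosition(0) == 1 and
  // getDimPosition(1) == 0, so the loop above stores rev = [1, 0]; without
  // an ordering, the identity rev = [0, 1] is used instead.
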
  // Secondary and primary types encoding.
  Type elemTp = stp.getElementType();
  params.push_back(constantPointerTypeEncoding(rewriter, loc, enc));
  params.push_back(constantIndexTypeEncoding(rewriter, loc, enc));
  params.push_back(constantPrimaryTypeEncoding(rewriter, loc, elemTp));
  // User action.
  params.push_back(constantAction(rewriter, loc, action));
  // Payload pointer.
  if (!ptr)
    ptr = rewriter.create<LLVM::NullOp>(loc, getOpaquePointerType(rewriter));
  params.push_back(ptr);
}

/// Generates the code to read the value from tensor[ivs], and conditionally
/// stores the indices ivs to the memory in ind. The generated code looks
/// like the following, and the insertion point after this routine is inside
/// the if-then branch, after the assignment to ind. This is to ensure that
/// the addEltX call generated after this routine is inside the if-then
/// branch.
///    if (tensor[ivs] != 0)
///      ind = ivs
static Value genIndexAndValueForDense(ConversionPatternRewriter &rewriter,
                                      Location loc, Value tensor, Value ind,
                                      ValueRange ivs) {
  Value val = rewriter.create<tensor::ExtractOp>(loc, tensor, ivs);
  Value cond = genIsNonzero(rewriter, loc, val);
  scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ false);
  rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
  unsigned i = 0;
  for (auto iv : ivs) {
    Value idx = constantIndex(rewriter, loc, i++);
    rewriter.create<memref::StoreOp>(loc, iv, ind, idx);
  }
  return val;
}

/// Generates a call to release/delete a `SparseTensorCOO`.
static void genDelCOOCall(OpBuilder &builder, Operation *op, Type elemTp,
                          Value coo) {
  SmallString<21> name{"delSparseTensorCOO", primaryTypeFunctionSuffix(elemTp)};
  TypeRange noTp;
  createFuncCall(builder, op, name, noTp, coo, EmitCInterface::Off);
}

/// Generates a call that adds one element to a coordinate scheme.
/// In particular, this generates code like the following:
///    val = a[i1,..,ik];
///    if val != 0
///      t->add(val, [i1,..,ik], [p1,..,pk]);
static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
                          Type eltType, Value ptr, Value val, Value ind,
                          Value perm) {
  SmallString<9> name{"addElt", primaryTypeFunctionSuffix(eltType)};
  SmallVector<Value, 4> params{ptr, val, ind, perm};
  Type pTp = getOpaquePointerType(rewriter);
  createFuncCall(rewriter, op, name, pTp, params, EmitCInterface::On);
}

/// Generates a call to `iter->getNext()`. If there is a next element,
/// then it is copied into the out-parameters `ind` and `elemPtr`,
/// and the return value is true. If there isn't a next element, then
/// the memory for `iter` is freed and the return value is false.
static Value genGetNextCall(ConversionPatternRewriter &rewriter, Operation *op,
                            Value iter, Value ind, Value elemPtr) {
  Type elemTp = elemPtr.getType().cast<ShapedType>().getElementType();
  SmallString<10> name{"getNext", primaryTypeFunctionSuffix(elemTp)};
  SmallVector<Value, 3> params{iter, ind, elemPtr};
  Type i1 = rewriter.getI1Type();
  return createFuncCall(rewriter, op, name, i1, params, EmitCInterface::On)
      .getResult(0);
}
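
// For example (a sketch of the name mangling; the suffixes are produced by
// primaryTypeFunctionSuffix), an f64 element type makes the three helpers
// above target the runtime entry points `delSparseTensorCOOF64`,
// `addEltF64`, and `getNextF64`, respectively.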

/// If the tensor is a sparse constant, generates and returns the pair of
/// constants for the indices and the values.
static Optional<std::pair<Value, Value>>
genSplitSparseConstant(ConversionPatternRewriter &rewriter, Location loc,
                       Value tensor) {
  if (auto constOp = tensor.getDefiningOp<arith::ConstantOp>()) {
    if (auto attr = constOp.getValue().dyn_cast<SparseElementsAttr>()) {
      DenseElementsAttr indicesAttr = attr.getIndices();
      Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr);
      DenseElementsAttr valuesAttr = attr.getValues();
      Value values = rewriter.create<arith::ConstantOp>(loc, valuesAttr);
      return std::make_pair(indices, values);
    }
  }
  return {};
}

/// Generates the code to copy the index at indices[ivs] to ind, and returns
/// the value at values[ivs].
static Value genIndexAndValueForSparse(ConversionPatternRewriter &rewriter,
                                       Location loc, Value indices,
                                       Value values, Value ind, ValueRange ivs,
                                       unsigned rank) {
  for (unsigned i = 0; i < rank; i++) {
    Value idx = constantIndex(rewriter, loc, i);
    Value val = rewriter.create<tensor::ExtractOp>(loc, indices,
                                                   ValueRange{ivs[0], idx});
    val =
        rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), val);
    rewriter.create<memref::StoreOp>(loc, val, ind, idx);
  }
  return rewriter.create<tensor::ExtractOp>(loc, values, ivs[0]);
}

/// Generates code to allocate a tensor of the given type, and zero
/// initialize it. If the tensor type has any dynamic sizes, then the
/// `sizes` parameter should be as filled by sizesFromPtr(); that way
/// we can reuse the genDimSizeCall() results generated by sizesFromPtr().
static Value allocDenseTensor(ConversionPatternRewriter &rewriter,
                              Location loc, RankedTensorType tensorTp,
                              ValueRange sizes) {
  Type elemTp = tensorTp.getElementType();
  auto shape = tensorTp.getShape();
  auto memTp = MemRefType::get(shape, elemTp);
  SmallVector<Value> dynamicSizes;
  for (unsigned i = 0, rank = tensorTp.getRank(); i < rank; i++) {
    if (shape[i] == ShapedType::kDynamicSize)
      dynamicSizes.push_back(sizes[i]);
  }
  Value mem = rewriter.create<memref::AllocOp>(loc, memTp, dynamicSizes);
  Value zero = constantZero(rewriter, loc, elemTp);
  rewriter.create<linalg::FillOp>(loc, ValueRange{zero}, ValueRange{mem});
  return mem;
}

/// Inserts the element returned by genGetNextCall(_, ind, elemPtr) into
/// the tensor created by allocDenseTensor(). The `rank` is the rank
/// of the `tensor` and the length of `ind`.
static void insertScalarIntoDenseTensor(ConversionPatternRewriter &rewriter,
                                        Location loc, Value elemPtr,
                                        Value tensor, unsigned rank,
                                        Value ind) {
  SmallVector<Value, 4> ivs;
  ivs.reserve(rank);
  for (unsigned i = 0; i < rank; i++) {
    Value idx = constantIndex(rewriter, loc, i);
    ivs.push_back(rewriter.create<memref::LoadOp>(loc, ind, idx));
  }
  Value elemV = rewriter.create<memref::LoadOp>(loc, elemPtr);
  rewriter.create<memref::StoreOp>(loc, elemV, tensor, ivs);
}
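
// For example (a sketch in pseudo-IR; exact syntax may vary), for a type
// tensor<?x4xf64> whose dynamic size was resolved to %d0, allocDenseTensor()
// emits roughly:
//   %buf = memref.alloc(%d0) : memref<?x4xf64>
//   linalg.fill(%zero, %buf)  // zero-initialize all elements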

//===----------------------------------------------------------------------===//
// Conversion rules.
//===----------------------------------------------------------------------===//

/// Sparse conversion rule for returns.
class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for dimension accesses.
class SparseTensorToDimSizeConverter
    : public OpConversionPattern<tensor::DimOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only rewrite an annotated DimOp with a constant index.
    auto enc = getSparseTensorEncoding(op.source().getType());
    if (!enc)
      return failure();
    Optional<int64_t> index = op.getConstantIndex();
    if (!index.hasValue())
      return failure();
    // Generate the call.
    Value src = adaptor.getOperands()[0];
    int64_t idx = index.getValue();
    rewriter.replaceOp(op, genDimSizeCall(rewriter, op, enc, src, idx));
    return success();
  }
};

/// Sparse conversion rule for trivial tensor casts.
class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only rewrite identically annotated source/dest.
    auto encDst = getSparseTensorEncoding(op.getType());
    auto encSrc = getSparseTensorEncoding(op.source().getType());
    if (!encDst || encDst != encSrc)
      return failure();
    rewriter.replaceOp(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for the new operator.
class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NewOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    auto enc = getSparseTensorEncoding(resType);
    if (!enc)
      return failure();
    // Generate the call to construct the tensor from the pointer. The sizes
    // are inferred from the result type of the new operator.
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 8> params;
    ShapedType stp = resType.cast<ShapedType>();
    sizesFromType(rewriter, sizes, op.getLoc(), stp);
    Value ptr = adaptor.getOperands()[0];
    newParams(rewriter, params, op, stp, enc, Action::kFromFile, sizes, ptr);
    rewriter.replaceOp(op, genNewCall(rewriter, op, params));
    return success();
  }
};
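
// For example (a sketch; the #CSR encoding name and SSA names are
// illustrative), the rule above lowers
//   %t = sparse_tensor.new %ptr : !llvm.ptr<i8> to tensor<?x?xf64, #CSR>
// to a single runtime call
//   %t = call @newSparseTensor(%annots, %sizes, %perm, ..., %ptr)
// where the action operand encodes Action::kFromFile, so the library reads
// and constructs the tensor from the given file pointer.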

/// Sparse conversion rule for the init operator.
class SparseTensorInitConverter : public OpConversionPattern<InitOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(InitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    auto enc = getSparseTensorEncoding(resType);
    if (!enc)
      return failure();
    // Generate the call to construct an empty tensor. The sizes are
    // explicitly defined by the arguments to the init operator.
    SmallVector<Value, 8> params;
    ShapedType stp = resType.cast<ShapedType>();
    newParams(rewriter, params, op, stp, enc, Action::kEmpty,
              adaptor.getOperands());
    rewriter.replaceOp(op, genNewCall(rewriter, op, params));
    return success();
  }
};

/// Sparse conversion rule for the convert operator.
class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
  /// Options to control sparse code generation.
  SparseTensorConversionOptions options;

public:
  using OpConversionPattern::OpConversionPattern;
  SparseTensorConvertConverter(MLIRContext *context,
                               SparseTensorConversionOptions o)
      : OpConversionPattern<ConvertOp>(context), options(o) {}
  SparseTensorConvertConverter(TypeConverter &typeConv, MLIRContext *context,
                               SparseTensorConversionOptions o)
      : OpConversionPattern<ConvertOp>(typeConv, context), options(o) {}

  LogicalResult
  matchAndRewrite(ConvertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    Type resType = op.getType();
    Type srcType = op.source().getType();
    auto encDst = getSparseTensorEncoding(resType);
    auto encSrc = getSparseTensorEncoding(srcType);
    Value src = adaptor.getOperands()[0];
    if (encDst && encSrc) {
      // This is a sparse => sparse conversion, which is handled as follows:
      //   t = src->toCOO();         ; src to COO in dst order
      //   dst = newSparseTensor(t)
      // Using the coordinate scheme as an intermediate does not always
      // yield the fastest conversion but avoids the need for a full
      // O(N^2) conversion matrix.
      if (encDst == encSrc) {
        rewriter.replaceOp(op, adaptor.getOperands()); // hidden nop cast
        return success();
      }
      SmallVector<Value, 4> sizes;
      SmallVector<Value, 8> params;
      ShapedType stp = srcType.cast<ShapedType>();
      sizesFromPtr(rewriter, sizes, op, encSrc, stp, src);
      // Set up the encoding with the right mix of src and dst so that the
      // two method calls can share most parameters, while still providing
      // the correct sparsity information to either of them.
      auto enc = SparseTensorEncodingAttr::get(
          op->getContext(), encDst.getDimLevelType(), encDst.getDimOrdering(),
          encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
      newParams(rewriter, params, op, stp, enc, Action::kToCOO, sizes, src);
      Value coo = genNewCall(rewriter, op, params);
      params[3] = constantPointerTypeEncoding(rewriter, loc, encDst);
      params[4] = constantIndexTypeEncoding(rewriter, loc, encDst);
      params[6] = constantAction(rewriter, loc, Action::kFromCOO);
      params[7] = coo;
      Value dst = genNewCall(rewriter, op, params);
      genDelCOOCall(rewriter, op, stp.getElementType(), coo);
      rewriter.replaceOp(op, dst);
      return success();
    }
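
    // For example (a sketch; encoding names illustrative), converting
    // tensor<?x?xf64, #CSR> to tensor<?x?xf64, #CSC> with the sparse =>
    // sparse case above amounts to:
    //   %coo = call @newSparseTensor(...) // Action::kToCOO on the source
    //   %dst = call @newSparseTensor(...) // Action::kFromCOO, dst annotations
    //   call @delSparseTensorCOOF64(%coo) // release the intermediate
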
    if (!encDst && encSrc) {
      // This is a sparse => dense conversion, which is handled as follows:
      //   dst = new Tensor(0);
      //   iter = src->toCOO();
      //   iter->startIterator();
      //   while (elem = iter->getNext()) {
      //     dst[elem.indices] = elem.value;
      //   }
      RankedTensorType dstTensorTp = resType.cast<RankedTensorType>();
      RankedTensorType srcTensorTp = srcType.cast<RankedTensorType>();
      unsigned rank = dstTensorTp.getRank();
      Type elemTp = dstTensorTp.getElementType();
      // Fabricate a no-permutation encoding for newParams().
      // The pointer/index types must be those of `src`.
      // The dimLevelTypes aren't actually used by Action::kToIterator.
      encDst = SparseTensorEncodingAttr::get(
          op->getContext(),
          SmallVector<SparseTensorEncodingAttr::DimLevelType>(
              rank, SparseTensorEncodingAttr::DimLevelType::Dense),
          AffineMap(), encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
      SmallVector<Value, 4> sizes;
      SmallVector<Value, 8> params;
      sizesFromPtr(rewriter, sizes, op, encSrc, srcTensorTp, src);
      newParams(rewriter, params, op, dstTensorTp, encDst, Action::kToIterator,
                sizes, src);
      Value iter = genNewCall(rewriter, op, params);
      Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
      Value elemPtr = genAllocaScalar(rewriter, loc, elemTp);
      Value dst = allocDenseTensor(rewriter, loc, dstTensorTp, sizes);
      SmallVector<Value> noArgs;
      SmallVector<Type> noTypes;
      auto whileOp = rewriter.create<scf::WhileOp>(loc, noTypes, noArgs);
      Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, noTypes);
      rewriter.setInsertionPointToEnd(before);
      Value cond = genGetNextCall(rewriter, op, iter, ind, elemPtr);
      rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments());
      Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, noTypes);
      rewriter.setInsertionPointToStart(after);
      insertScalarIntoDenseTensor(rewriter, loc, elemPtr, dst, rank, ind);
      rewriter.create<scf::YieldOp>(loc);
      rewriter.setInsertionPointAfter(whileOp);
      genDelCOOCall(rewriter, op, elemTp, iter);
      rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, resType, dst);
      return success();
    }
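
    // For example (a sketch; SSA names illustrative), for f64 elements the
    // sparse => dense case above generates a loop of the form:
    //   scf.while : () -> () {
    //     %cond = call @getNextF64(%iter, %ind, %elemPtr)
    //     scf.condition(%cond)
    //   } do {
    //     ... load indices from %ind, store the value at %elemPtr into %dst
    //     scf.yield
    //   }
    // followed by a delSparseTensorCOOF64 call that frees the iterator.
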
    if (!encDst && !encSrc) {
      // dense => dense
      return failure();
    }
    // This is a dense => sparse conversion or a sparse constant in COO =>
    // sparse conversion, which is handled as follows:
    //   t = newSparseCOO()
    //   ...code to fill the COO tensor t...
    //   s = newSparseTensor(t)
    //
    // To fill the COO tensor from a dense tensor:
    //   for i1 in dim1
    //     ..
    //       for ik in dimk
    //         val = a[i1,..,ik]
    //         if val != 0
    //           t->add(val, [i1,..,ik], [p1,..,pk])
    //
    // To fill the COO tensor from a sparse constant in COO format:
    //   for i in range(NNZ)
    //     val = values[i]
    //     [i1,..,ik] = indices[i]
    //     t->add(val, [i1,..,ik], [p1,..,pk])
    //
    // Note that the dense tensor traversal code is actually implemented
    // using MLIR IR to avoid having to expose too much low-level
    // memref traversal details to the runtime support library.
    // Also note that the code below only generates the "new" ops and
    // the loop-nest per se, whereas the entire body of the innermost
    // loop is generated by genAddEltCall().
    ShapedType stp = resType.cast<ShapedType>();
    unsigned rank = stp.getRank();
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 8> params;
    sizesFromSrc(rewriter, sizes, loc, src);
    newParams(rewriter, params, op, stp, encDst, Action::kEmptyCOO, sizes);
    Value coo = genNewCall(rewriter, op, params);
    Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
    Value perm = params[2];
    SmallVector<Value> lo;
    SmallVector<Value> hi;
    SmallVector<Value> st;
    Value zero = constantIndex(rewriter, loc, 0);
    Value one = constantIndex(rewriter, loc, 1);
    auto indicesValues = genSplitSparseConstant(rewriter, loc, src);
    bool isCOOConstant = indicesValues.hasValue();
    Value indices;
    Value values;
    if (isCOOConstant) {
      indices = indicesValues->first;
      values = indicesValues->second;
      lo.push_back(zero);
      hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, values, 0));
      st.push_back(one);
    } else {
      for (unsigned i = 0; i < rank; i++) {
        lo.push_back(zero);
        hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, src, i));
        st.push_back(one);
      }
    }
    Type eltType = stp.getElementType();
    scf::buildLoopNest(
        rewriter, op.getLoc(), lo, hi, st, {},
        [&](OpBuilder &builder, Location loc, ValueRange ivs,
            ValueRange args) -> scf::ValueVector {
          Value val;
          if (isCOOConstant)
            val = genIndexAndValueForSparse(rewriter, loc, indices, values,
                                            ind, ivs, rank);
          else
            val = genIndexAndValueForDense(rewriter, loc, src, ind, ivs);
          genAddEltCall(rewriter, op, eltType, coo, val, ind, perm);
          return {};
        });
    // Final call to construct the sparse tensor storage.
    params[6] = constantAction(rewriter, loc, Action::kFromCOO);
    params[7] = coo;
    Value dst = genNewCall(rewriter, op, params);
    genDelCOOCall(rewriter, op, eltType, coo);
    rewriter.replaceOp(op, dst);
    return success();
  }
};

/// Sparse conversion rule for the release operator.
class SparseTensorReleaseConverter : public OpConversionPattern<ReleaseOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReleaseOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    StringRef name = "delSparseTensor";
    TypeRange noTp;
    createFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
                   EmitCInterface::Off);
    rewriter.eraseOp(op);
    return success();
  }
};

/// Sparse conversion rule for pointer accesses.
class SparseTensorToPointersConverter
    : public OpConversionPattern<ToPointersOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToPointersOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type ptrType = resType.cast<ShapedType>().getElementType();
    SmallString<16> name{"sparsePointers", overheadTypeFunctionSuffix(ptrType)};
    replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
                          EmitCInterface::On);
    return success();
  }
};
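
// For example (a sketch; the suffix is produced by
// overheadTypeFunctionSuffix), an i32 pointer overhead type resolves the
// call above to the runtime entry point `sparsePointers32`.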

/// Sparse conversion rule for index accesses.
class SparseTensorToIndicesConverter
    : public OpConversionPattern<ToIndicesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToIndicesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type indType = resType.cast<ShapedType>().getElementType();
    SmallString<15> name{"sparseIndices", overheadTypeFunctionSuffix(indType)};
    replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
                          EmitCInterface::On);
    return success();
  }
};

/// Sparse conversion rule for value accesses.
class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type eltType = resType.cast<ShapedType>().getElementType();
    SmallString<15> name{"sparseValues", primaryTypeFunctionSuffix(eltType)};
    replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
                          EmitCInterface::On);
    return success();
  }
};

/// Sparse conversion rule for tensor rematerialization.
class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.hasInserts()) {
      // Finalize any pending insertions.
      StringRef name = "endInsert";
      TypeRange noTp;
      createFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
                     EmitCInterface::Off);
    }
    rewriter.replaceOp(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for inserting in lexicographic index order.
class SparseTensorLexInsertConverter
    : public OpConversionPattern<LexInsertOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(LexInsertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type elemTp = op.tensor().getType().cast<ShapedType>().getElementType();
    SmallString<12> name{"lexInsert", primaryTypeFunctionSuffix(elemTp)};
    TypeRange noTp;
    replaceOpWithFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
                          EmitCInterface::On);
    return success();
  }
};
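
// For example (a sketch; the suffix is produced by
// primaryTypeFunctionSuffix), inserting f64 values lowers
// sparse_tensor.lex_insert to a call to the runtime entry point
// `lexInsertF64`, passing the tensor, the cursor of indices, and the value.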

/// Sparse conversion rule for the expand operator.
class SparseTensorExpandConverter : public OpConversionPattern<ExpandOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ExpandOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    ShapedType srcType = op.tensor().getType().cast<ShapedType>();
    Type eltType = srcType.getElementType();
    Type boolType = rewriter.getIntegerType(1);
    Type idxType = rewriter.getIndexType();
    // All initialization should be done on entry of the loop nest.
    rewriter.setInsertionPointAfter(op.tensor().getDefiningOp());
    // Determine the size for access expansion.
    auto enc = getSparseTensorEncoding(srcType);
    Value src = adaptor.getOperands()[0];
    Value sz = genDimSizeCall(rewriter, op, enc, src, srcType.getRank() - 1);
    // Allocate temporary buffers for values, filled-switch, and indices.
    // We do not use stack buffers for this, since the expanded size may
    // be rather large (as it envelops a single expanded dense dimension).
    Value values = genAlloc(rewriter, loc, sz, eltType);
    Value filled = genAlloc(rewriter, loc, sz, boolType);
    Value indices = genAlloc(rewriter, loc, sz, idxType);
    Value zero = constantZero(rewriter, loc, idxType);
    // Reset the values/filled-switch to all-zero/false. Note that this
    // introduces an O(N) operation into the computation, but this reset
    // operation is amortized over the innermost loops for the access
    // pattern expansion. As noted in the operation doc, we would like
    // to amortize this setup cost even between kernels.
    rewriter.create<linalg::FillOp>(
        loc, ValueRange{constantZero(rewriter, loc, eltType)},
        ValueRange{values});
    rewriter.create<linalg::FillOp>(
        loc, ValueRange{constantZero(rewriter, loc, boolType)},
        ValueRange{filled});
    // Replace the expansion op with these buffers and the initial index.
    assert(op.getNumResults() == 4);
    rewriter.replaceOp(op, {values, filled, indices, zero});
    return success();
  }
};

/// Sparse conversion rule for the compress operator.
class SparseTensorCompressConverter : public OpConversionPattern<CompressOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(CompressOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    // Note that this method call resets the values/filled-switch back to
    // all-zero/false by only iterating over the set elements, so the
    // complexity remains proportional to the sparsity of the expanded
    // access pattern.
    Type elemTp = op.tensor().getType().cast<ShapedType>().getElementType();
    SmallString<12> name{"expInsert", primaryTypeFunctionSuffix(elemTp)};
    TypeRange noTp;
    replaceOpWithFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
                          EmitCInterface::On);
    // Deallocate the buffers on exit of the loop nest.
    Operation *parent = op;
    while (isa<scf::ForOp, scf::WhileOp, scf::ParallelOp, scf::IfOp>(
        parent->getParentOp()))
      parent = parent->getParentOp();
    rewriter.setInsertionPointAfter(parent);
    rewriter.create<memref::DeallocOp>(loc, adaptor.getOperands()[2]);
    rewriter.create<memref::DeallocOp>(loc, adaptor.getOperands()[3]);
    rewriter.create<memref::DeallocOp>(loc, adaptor.getOperands()[4]);
    return success();
  }
};
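
// For example (a sketch; SSA names illustrative), for f64 values the
// expand/compress pair above brackets a kernel as follows:
//   %values = memref.alloc(%sz) : memref<?xf64>    // expanded values
//   %filled = memref.alloc(%sz) : memref<?xi1>     // filled-switch
//   %added  = memref.alloc(%sz) : memref<?xindex>  // indices of set entries
//   ... innermost loops fill the expanded access pattern ...
//   call @expInsertF64(%tensor, %cursor, %values, %filled, %added, %count)
//   // followed by one memref.dealloc for each of the three buffers.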

/// Sparse conversion rule for the output operator.
class SparseTensorOutConverter : public OpConversionPattern<OutOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(OutOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    ShapedType srcType = op.tensor().getType().cast<ShapedType>();
    // Convert to default permuted COO.
    Value src = adaptor.getOperands()[0];
    auto encSrc = getSparseTensorEncoding(srcType);
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 8> params;
    sizesFromPtr(rewriter, sizes, op, encSrc, srcType, src);
    auto enc = SparseTensorEncodingAttr::get(
        op->getContext(), encSrc.getDimLevelType(), AffineMap(),
        encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
    newParams(rewriter, params, op, srcType, enc, Action::kToCOO, sizes, src);
    Value coo = genNewCall(rewriter, op, params);
    // Then output the tensor to an external file with the indices in the
    // externally visible lexicographic index order. A sort is required if
    // the source was not in that order yet (note that the sort can be
    // dropped altogether if the external format does not care about the
    // order at all, but here we assume it does).
    bool sort =
        encSrc.getDimOrdering() && !encSrc.getDimOrdering().isIdentity();
    params.clear();
    params.push_back(coo);
    params.push_back(adaptor.getOperands()[1]);
    params.push_back(constantI1(rewriter, loc, sort));
    Type eltType = srcType.getElementType();
    SmallString<18> name{"outSparseTensor", primaryTypeFunctionSuffix(eltType)};
    TypeRange noTp;
    createFuncCall(rewriter, op, name, noTp, params, EmitCInterface::Off);
    genDelCOOCall(rewriter, op, eltType, coo);
    rewriter.eraseOp(op);
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with conversion rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparseTensorConversionPatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    const SparseTensorConversionOptions &options) {
  patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter,
               SparseCastConverter, SparseTensorNewConverter,
               SparseTensorInitConverter, SparseTensorReleaseConverter,
               SparseTensorToPointersConverter, SparseTensorToIndicesConverter,
               SparseTensorToValuesConverter, SparseTensorLoadConverter,
               SparseTensorLexInsertConverter, SparseTensorExpandConverter,
               SparseTensorCompressConverter, SparseTensorOutConverter>(
      typeConverter, patterns.getContext());
  patterns.add<SparseTensorConvertConverter>(typeConverter,
                                             patterns.getContext(), options);
}
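
// A typical driver (a sketch; the surrounding pass boilerplate shown here is
// illustrative, not the actual pass definition) wires these rules into the
// dialect conversion framework roughly as follows:
//   RewritePatternSet patterns(ctx);
//   populateSparseTensorConversionPatterns(typeConverter, patterns, options);
//   if (failed(applyPartialConversion(getOperation(), target,
//                                     std::move(patterns))))
//     signalPassFailure();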