//===- SparseTensorConversion.cpp - Sparse tensor primitives conversion ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Convert sparse tensor primitives to calls into a runtime support library.
// Note that this is a current implementation choice to keep the conversion
// simple. In principle, these primitives could also be converted to actual
// elaborate IR code that implements the primitives on the selected sparse
// tensor storage schemes.
//
//===----------------------------------------------------------------------===//
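
// For example, under a CSR-style annotation #CSR (illustrative), a primitive
// such as
//
//   %v = sparse_tensor.values %t : tensor<?x?xf64, #CSR> to memref<?xf64>
//
// is rewritten by the rules below into a library call along the lines of
//
//   %v = call @sparseValuesF64(%t) : (!llvm.ptr<i8>) -> memref<?xf64>
//
// where the sparse tensor itself is passed around as an opaque pointer.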

#include "CodegenUtils.h"

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/ExecutionEngine/SparseTensorUtils.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

/// Shorthand aliases for the `emitCInterface` argument to `getFunc()`,
/// `createFuncCall()`, and `replaceOpWithFuncCall()`.
enum class EmitCInterface : bool { Off = false, On = true };

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

/// Returns the equivalent of `void*` for opaque arguments to the
/// execution engine.
static Type getOpaquePointerType(PatternRewriter &rewriter) {
  return LLVM::LLVMPointerType::get(rewriter.getI8Type());
}

/// Returns a function reference (first hit also inserts into module). Sets
/// the "llvm.emit_c_interface" attribute on the function declaration when
/// requested, so that LLVM lowering generates a wrapper function that takes
/// care of ABI complications with passing in and returning MemRefs to C
/// functions.
static FlatSymbolRefAttr getFunc(Operation *op, StringRef name,
                                 TypeRange resultType, ValueRange operands,
                                 EmitCInterface emitCInterface) {
  MLIRContext *context = op->getContext();
  auto module = op->getParentOfType<ModuleOp>();
  auto result = SymbolRefAttr::get(context, name);
  auto func = module.lookupSymbol<FuncOp>(result.getAttr());
  if (!func) {
    OpBuilder moduleBuilder(module.getBodyRegion());
    func = moduleBuilder.create<FuncOp>(
        op->getLoc(), name,
        FunctionType::get(context, operands.getTypes(), resultType));
    func.setPrivate();
    if (static_cast<bool>(emitCInterface))
      func->setAttr("llvm.emit_c_interface", UnitAttr::get(context));
  }
  return result;
}

/// Creates a `CallOp` to the function reference returned by `getFunc()`.
static CallOp createFuncCall(OpBuilder &builder, Operation *op, StringRef name,
                             TypeRange resultType, ValueRange operands,
                             EmitCInterface emitCInterface) {
  auto fn = getFunc(op, name, resultType, operands, emitCInterface);
  return builder.create<CallOp>(op->getLoc(), resultType, fn, operands);
}

/// Replaces the `op` with a `CallOp` to the function reference returned
/// by `getFunc()`.
static CallOp replaceOpWithFuncCall(PatternRewriter &rewriter, Operation *op,
                                    StringRef name, TypeRange resultType,
                                    ValueRange operands,
                                    EmitCInterface emitCInterface) {
  auto fn = getFunc(op, name, resultType, operands, emitCInterface);
  return rewriter.replaceOpWithNewOp<CallOp>(op, resultType, fn, operands);
}
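
// For example, the first createFuncCall() for "sparseDimSize" also
// materializes a private declaration at module scope, roughly
//
//   func private @sparseDimSize(!llvm.ptr<i8>, index) -> index
//
// and subsequent calls with the same name simply reuse that declaration.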

/// Generates dimension size call.
static Value genDimSizeCall(ConversionPatternRewriter &rewriter, Operation *op,
                            SparseTensorEncodingAttr &enc, Value src,
                            int64_t idx) {
  // Permute the index according to an optional dimension ordering.
  if (AffineMap p = enc.getDimOrdering())
    idx = p.getPermutedPosition(idx);
  // Generate the call.
  StringRef name = "sparseDimSize";
  SmallVector<Value, 2> params{src, constantIndex(rewriter, op->getLoc(), idx)};
  Type iTp = rewriter.getIndexType();
  return createFuncCall(rewriter, op, name, iTp, params, EmitCInterface::Off)
      .getResult(0);
}

/// Generates a call into the "swiss army knife" method of the sparse runtime
/// support library for materializing sparse tensors into the computation.
static Value genNewCall(ConversionPatternRewriter &rewriter, Operation *op,
                        ArrayRef<Value> params) {
  StringRef name = "newSparseTensor";
  Type pTp = getOpaquePointerType(rewriter);
  return createFuncCall(rewriter, op, name, pTp, params, EmitCInterface::On)
      .getResult(0);
}

/// Populates given sizes array from type.
static void sizesFromType(ConversionPatternRewriter &rewriter,
                          SmallVector<Value, 4> &sizes, Location loc,
                          ShapedType stp) {
  auto shape = stp.getShape();
  for (unsigned i = 0, rank = stp.getRank(); i < rank; i++) {
    uint64_t s = shape[i] == ShapedType::kDynamicSize ? 0 : shape[i];
    sizes.push_back(constantIndex(rewriter, loc, s));
  }
}

/// Populates given sizes array from source.
static void sizesFromSrc(ConversionPatternRewriter &rewriter,
                         SmallVector<Value, 4> &sizes, Location loc,
                         Value src) {
  unsigned rank = src.getType().cast<ShapedType>().getRank();
  for (unsigned i = 0; i < rank; i++)
    sizes.push_back(linalg::createOrFoldDimOp(rewriter, loc, src, i));
}

/// Populates given sizes array from type (for static sizes) and from
/// a source that has already been converted into an opaque pointer
/// (for dynamic sizes).
static void sizesFromPtr(ConversionPatternRewriter &rewriter,
                         SmallVector<Value, 4> &sizes, Operation *op,
                         SparseTensorEncodingAttr &enc, ShapedType stp,
                         Value src) {
  Location loc = op->getLoc();
  auto shape = stp.getShape();
  for (unsigned i = 0, rank = stp.getRank(); i < rank; i++)
    if (shape[i] == ShapedType::kDynamicSize)
      sizes.push_back(genDimSizeCall(rewriter, op, enc, src, i));
    else
      sizes.push_back(constantIndex(rewriter, loc, shape[i]));
}

/// Generates an uninitialized temporary buffer of the given size and
/// type, but returns it as type `memref<? x $tp>` (rather than as type
/// `memref<$sz x $tp>`).
static Value genAlloca(ConversionPatternRewriter &rewriter, Location loc,
                       Value sz, Type tp) {
  auto memTp = MemRefType::get({ShapedType::kDynamicSize}, tp);
  return rewriter.create<memref::AllocaOp>(loc, memTp, ValueRange{sz});
}

/// Generates an uninitialized temporary buffer of the given size and
/// type, but returns it as type `memref<? x $tp>` (rather than as type
/// `memref<$sz x $tp>`).
static Value genAlloca(ConversionPatternRewriter &rewriter, Location loc,
                       unsigned sz, Type tp) {
  return genAlloca(rewriter, loc, constantIndex(rewriter, loc, sz), tp);
}

/// Generates an uninitialized temporary buffer with room for one value
/// of the given type, and returns the `memref<$tp>`.
static Value genAllocaScalar(ConversionPatternRewriter &rewriter, Location loc,
                             Type tp) {
  return rewriter.create<memref::AllocaOp>(loc, MemRefType::get({}, tp));
}

/// Generates a temporary buffer of the given type and given contents.
static Value genBuffer(ConversionPatternRewriter &rewriter, Location loc,
                       ArrayRef<Value> values) {
  unsigned sz = values.size();
  assert(sz >= 1);
  Value buffer = genAlloca(rewriter, loc, sz, values[0].getType());
  for (unsigned i = 0; i < sz; i++) {
    Value idx = constantIndex(rewriter, loc, i);
    rewriter.create<memref::StoreOp>(loc, values[i], buffer, idx);
  }
  return buffer;
}

/// Populates parameters required to call the "swiss army knife" method of the
/// sparse runtime support library for materializing sparse tensors into the
/// computation.
static void newParams(ConversionPatternRewriter &rewriter,
                      SmallVector<Value, 8> &params, Operation *op,
                      ShapedType stp, SparseTensorEncodingAttr &enc,
                      Action action, ValueRange szs, Value ptr = Value()) {
  Location loc = op->getLoc();
  ArrayRef<SparseTensorEncodingAttr::DimLevelType> dlt = enc.getDimLevelType();
  unsigned sz = dlt.size();
  // Sparsity annotations.
  SmallVector<Value, 4> attrs;
  for (unsigned i = 0; i < sz; i++)
    attrs.push_back(constantDimLevelTypeEncoding(rewriter, loc, dlt[i]));
  params.push_back(genBuffer(rewriter, loc, attrs));
  // Dimension sizes array of the enveloping tensor. Useful for either
  // verification of external data, or for construction of internal data.
  SmallVector<Value, 4> sizes;
  for (Value s : szs)
    sizes.push_back(s);
  params.push_back(genBuffer(rewriter, loc, sizes));
  // Dimension order permutation array. This is the "identity" permutation by
  // default, or otherwise the "reverse" permutation of a given ordering, so
  // that indices can be mapped quickly to the right position.
  SmallVector<Value, 4> rev(sz);
  if (AffineMap p = enc.getDimOrdering()) {
    for (unsigned i = 0; i < sz; i++)
      rev[p.getDimPosition(i)] = constantIndex(rewriter, loc, i);
  } else {
    for (unsigned i = 0; i < sz; i++)
      rev[i] = constantIndex(rewriter, loc, i);
  }
  params.push_back(genBuffer(rewriter, loc, rev));
  // Secondary and primary types encoding.
  Type elemTp = stp.getElementType();
  params.push_back(constantPointerTypeEncoding(rewriter, loc, enc));
  params.push_back(constantIndexTypeEncoding(rewriter, loc, enc));
  params.push_back(constantPrimaryTypeEncoding(rewriter, loc, elemTp));
  // User action.
  params.push_back(constantAction(rewriter, loc, action));
  // Payload pointer.
  if (!ptr)
    ptr = rewriter.create<LLVM::NullOp>(loc, getOpaquePointerType(rewriter));
  params.push_back(ptr);
}
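
// For later reference, the parameter vector assembled above is laid out as:
//   params[0]  dimension level types buffer
//   params[1]  dimension sizes buffer
//   params[2]  dimension order permutation buffer
//   params[3]  pointer type encoding
//   params[4]  index type encoding
//   params[5]  primary (element) type encoding
//   params[6]  action
//   params[7]  payload pointer
// Several conversion rules below overwrite individual entries (for example
// params[6] and params[7]) to issue a follow-up newSparseTensor() call.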

/// Generates the code to read the value from tensor[ivs] and to conditionally
/// store the indices ivs to the memory in ind. The generated code looks like
/// the following, and the insertion point after this routine is inside the
/// if-then branch, right after the assignment to ind. This ensures that the
/// addEltX call generated subsequently is also inside the if-then branch.
///    if (tensor[ivs] != 0) {
///      ind = ivs
static Value genIndexAndValueForDense(ConversionPatternRewriter &rewriter,
                                      Location loc, Value tensor, Value ind,
                                      ValueRange ivs) {
  Value val = rewriter.create<tensor::ExtractOp>(loc, tensor, ivs);
  Value cond = genIsNonzero(rewriter, loc, val);
  scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ false);
  rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
  unsigned i = 0;
  for (auto iv : ivs) {
    Value idx = constantIndex(rewriter, loc, i++);
    rewriter.create<memref::StoreOp>(loc, iv, ind, idx);
  }
  return val;
}

/// Generates a call that adds one element to a coordinate scheme.
/// In particular, this generates code like the following:
///    val = a[i1,..,ik];
///    if val != 0
///      t->add(val, [i1,..,ik], [p1,..,pk]);
static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
                          Type eltType, Value ptr, Value val, Value ind,
                          Value perm) {
  SmallString<9> name{"addElt", primaryTypeFunctionSuffix(eltType)};
  SmallVector<Value, 4> params{ptr, val, ind, perm};
  Type pTp = getOpaquePointerType(rewriter);
  createFuncCall(rewriter, op, name, pTp, params, EmitCInterface::On);
}

/// Generates a call to `iter->getNext()`. If there is a next element,
/// then it is copied into the out-parameters `ind` and `elemPtr`,
/// and the return value is true. If there isn't a next element, then
/// the memory for `iter` is freed and the return value is false.
static Value genGetNextCall(ConversionPatternRewriter &rewriter, Operation *op,
                            Value iter, Value ind, Value elemPtr) {
  Type elemTp = elemPtr.getType().cast<ShapedType>().getElementType();
  SmallString<10> name{"getNext", primaryTypeFunctionSuffix(elemTp)};
  SmallVector<Value, 3> params{iter, ind, elemPtr};
  Type i1 = rewriter.getI1Type();
  return createFuncCall(rewriter, op, name, i1, params, EmitCInterface::On)
      .getResult(0);
}

/// If the tensor is a sparse constant, generates and returns the pair of
/// the constants for the indices and the values.
static Optional<std::pair<Value, Value>>
genSplitSparseConstant(ConversionPatternRewriter &rewriter, Location loc,
                       Value tensor) {
  if (auto constOp = tensor.getDefiningOp<arith::ConstantOp>()) {
    if (auto attr = constOp.getValue().dyn_cast<SparseElementsAttr>()) {
      DenseElementsAttr indicesAttr = attr.getIndices();
      Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr);
      DenseElementsAttr valuesAttr = attr.getValues();
      Value values = rewriter.create<arith::ConstantOp>(loc, valuesAttr);
      return std::make_pair(indices, values);
    }
  }
  return {};
}
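
// For example, a constant such as
//
//   sparse<[[0, 0], [1, 2]], [1.0, 5.0]> : tensor<8x8xf64>
//
// splits into an index constant of type tensor<2x2xi64> paired with a value
// constant of type tensor<2xf64> (illustrative shapes for two nonzeros in a
// rank-2 tensor).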

/// Generates the code to copy the index at indices[ivs] to ind, and returns
/// the value at values[ivs].
static Value genIndexAndValueForSparse(ConversionPatternRewriter &rewriter,
                                       Location loc, Value indices,
                                       Value values, Value ind, ValueRange ivs,
                                       unsigned rank) {
  for (unsigned i = 0; i < rank; i++) {
    Value idx = constantIndex(rewriter, loc, i);
    Value val = rewriter.create<tensor::ExtractOp>(loc, indices,
                                                   ValueRange{ivs[0], idx});
    val =
        rewriter.create<arith::IndexCastOp>(loc, val, rewriter.getIndexType());
    rewriter.create<memref::StoreOp>(loc, val, ind, idx);
  }
  return rewriter.create<tensor::ExtractOp>(loc, values, ivs[0]);
}

/// Generates code to allocate a tensor of the given type, and zero
/// initialize it. If the tensor type has any dynamic sizes, then the
/// `sizes` parameter should be as filled by sizesFromPtr(); that way
/// we can reuse the genDimSizeCall() results generated by sizesFromPtr().
static Value allocDenseTensor(ConversionPatternRewriter &rewriter, Location loc,
                              RankedTensorType tensorTp, ValueRange sizes) {
  Type elemTp = tensorTp.getElementType();
  auto shape = tensorTp.getShape();
  auto memTp = MemRefType::get(shape, elemTp);
  SmallVector<Value> dynamicSizes;
  for (unsigned i = 0, rank = tensorTp.getRank(); i < rank; i++) {
    if (shape[i] == ShapedType::kDynamicSize)
      dynamicSizes.push_back(sizes[i]);
  }
  Value mem = rewriter.create<memref::AllocOp>(loc, memTp, dynamicSizes);
  Value zero = constantZero(rewriter, loc, elemTp);
  rewriter.create<linalg::FillOp>(loc, zero, mem);
  return mem;
}
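
// For example, for a tensor<?x4xf64> with dynamic size %d0, the code
// generated above looks roughly like
//
//   %buf = memref.alloc(%d0) : memref<?x4xf64>
//   linalg.fill(%zero, %buf) : f64, memref<?x4xf64>
//
// where %zero is the f64 zero constant.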

/// Inserts the element returned by genGetNextCall(_, ind, elemPtr) into
/// the tensor created by allocDenseTensor(). The `rank` is the rank
/// of the `tensor` and the length of `ind`.
static void insertScalarIntoDenseTensor(ConversionPatternRewriter &rewriter,
                                        Location loc, Value elemPtr,
                                        Value tensor, unsigned rank,
                                        Value ind) {
  SmallVector<Value, 4> ivs;
  ivs.reserve(rank);
  for (unsigned i = 0; i < rank; i++) {
    Value idx = constantIndex(rewriter, loc, i);
    ivs.push_back(rewriter.create<memref::LoadOp>(loc, ind, idx));
  }
  Value elemV = rewriter.create<memref::LoadOp>(loc, elemPtr);
  rewriter.create<memref::StoreOp>(loc, elemV, tensor, ivs);
}

//===----------------------------------------------------------------------===//
// Conversion rules.
//===----------------------------------------------------------------------===//

/// Sparse conversion rule for returns.
class SparseReturnConverter : public OpConversionPattern<ReturnOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<ReturnOp>(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for dimension accesses.
class SparseTensorToDimSizeConverter
    : public OpConversionPattern<tensor::DimOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only rewrite annotated DimOp with constant index.
    auto enc = getSparseTensorEncoding(op.source().getType());
    if (!enc)
      return failure();
    Optional<int64_t> index = op.getConstantIndex();
    if (!index.hasValue())
      return failure();
    // Generate the call.
    Value src = adaptor.getOperands()[0];
    int64_t idx = index.getValue();
    rewriter.replaceOp(op, genDimSizeCall(rewriter, op, enc, src, idx));
    return success();
  }
};

/// Sparse conversion rule for trivial tensor casts.
class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only rewrite identically annotated source/dest.
    auto encDst = getSparseTensorEncoding(op.getType());
    auto encSrc = getSparseTensorEncoding(op.source().getType());
    if (!encDst || encDst != encSrc)
      return failure();
    rewriter.replaceOp(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for the new operator.
class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NewOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    auto enc = getSparseTensorEncoding(resType);
    if (!enc)
      return failure();
    // Generate the call to construct tensor from ptr. The sizes are
    // inferred from the result type of the new operator.
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 8> params;
    ShapedType stp = resType.cast<ShapedType>();
    sizesFromType(rewriter, sizes, op.getLoc(), stp);
    Value ptr = adaptor.getOperands()[0];
    newParams(rewriter, params, op, stp, enc, Action::kFromFile, sizes, ptr);
    rewriter.replaceOp(op, genNewCall(rewriter, op, params));
    return success();
  }
};

/// Sparse conversion rule for the init operator.
class SparseTensorInitConverter : public OpConversionPattern<InitOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(InitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    auto enc = getSparseTensorEncoding(resType);
    if (!enc)
      return failure();
    // Generate the call to construct empty tensor. The sizes are
    // explicitly defined by the arguments to the init operator.
    SmallVector<Value, 8> params;
    ShapedType stp = resType.cast<ShapedType>();
    newParams(rewriter, params, op, stp, enc, Action::kEmpty,
              adaptor.getOperands());
    rewriter.replaceOp(op, genNewCall(rewriter, op, params));
    return success();
  }
};
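
// For example, an operator such as
//
//   %t = sparse_tensor.init [%d0, %d1] : tensor<?x?xf64, #CSR>
//
// becomes a single newSparseTensor() call with Action::kEmpty, with %d0 and
// %d1 stored in the dimension sizes buffer.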

/// Sparse conversion rule for the convert operator.
class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ConvertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    Type resType = op.getType();
    Type srcType = op.source().getType();
    auto encDst = getSparseTensorEncoding(resType);
    auto encSrc = getSparseTensorEncoding(srcType);
    Value src = adaptor.getOperands()[0];
    if (encDst && encSrc) {
      // This is a sparse => sparse conversion, which is handled as follows:
      //   t = src->toCOO();         ; src to COO in dst order
      //   dst = newSparseTensor(t)
      // Using the coordinate scheme as an intermediate does not always
      // yield the fastest conversion but avoids the need for a full
      // O(N^2) conversion matrix.
      if (encDst == encSrc) {
        rewriter.replaceOp(op, adaptor.getOperands()); // hidden nop cast
        return success();
      }
      SmallVector<Value, 4> sizes;
      SmallVector<Value, 8> params;
      ShapedType stp = srcType.cast<ShapedType>();
      sizesFromPtr(rewriter, sizes, op, encSrc, stp, src);
      // Set up encoding with right mix of src and dst so that the two
      // method calls can share most parameters, while still providing
      // the correct sparsity information to either of them.
      auto enc = SparseTensorEncodingAttr::get(
          op->getContext(), encDst.getDimLevelType(), encDst.getDimOrdering(),
          encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
      newParams(rewriter, params, op, stp, enc, Action::kToCOO, sizes, src);
      Value coo = genNewCall(rewriter, op, params);
      params[3] = constantPointerTypeEncoding(rewriter, loc, encDst);
      params[4] = constantIndexTypeEncoding(rewriter, loc, encDst);
      params[6] = constantAction(rewriter, loc, Action::kFromCOO);
      params[7] = coo;
      rewriter.replaceOp(op, genNewCall(rewriter, op, params));
      return success();
    }
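    // The sparse => sparse case above thus emits two library calls, roughly
    // of the form
    //   %coo = call @newSparseTensor(...)  ; Action::kToCOO on %src
    //   %dst = call @newSparseTensor(...)  ; Action::kFromCOO on %coo
    // reusing one parameter buffer with only the pointer/index encodings,
    // the action, and the payload pointer swapped in between.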
    if (!encDst && encSrc) {
      // This is sparse => dense conversion, which is handled as follows:
      //   dst = new Tensor(0);
      //   iter = src->toCOO();
      //   iter->startIterator();
      //   while (elem = iter->getNext()) {
      //     dst[elem.indices] = elem.value;
      //   }
      RankedTensorType dstTensorTp = resType.cast<RankedTensorType>();
      RankedTensorType srcTensorTp = srcType.cast<RankedTensorType>();
      unsigned rank = dstTensorTp.getRank();
      Type elemTp = dstTensorTp.getElementType();
      // Fabricate a no-permutation encoding for newParams().
      // The pointer/index types must be those of `src`.
      // The dimLevelTypes aren't actually used by Action::kToIterator.
      encDst = SparseTensorEncodingAttr::get(
          op->getContext(),
          SmallVector<SparseTensorEncodingAttr::DimLevelType>(
              rank, SparseTensorEncodingAttr::DimLevelType::Dense),
          AffineMap(), encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
      SmallVector<Value, 4> sizes;
      SmallVector<Value, 8> params;
      sizesFromPtr(rewriter, sizes, op, encSrc, srcTensorTp, src);
      newParams(rewriter, params, op, dstTensorTp, encDst, Action::kToIterator,
                sizes, src);
      Value iter = genNewCall(rewriter, op, params);
      Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
      Value elemPtr = genAllocaScalar(rewriter, loc, elemTp);
      Value dst = allocDenseTensor(rewriter, loc, dstTensorTp, sizes);
      SmallVector<Value> noArgs;
      SmallVector<Type> noTypes;
      auto whileOp = rewriter.create<scf::WhileOp>(loc, noTypes, noArgs);
      Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, noTypes);
      rewriter.setInsertionPointToEnd(before);
      Value cond = genGetNextCall(rewriter, op, iter, ind, elemPtr);
      rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments());
      Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, noTypes);
      rewriter.setInsertionPointToStart(after);
      insertScalarIntoDenseTensor(rewriter, loc, elemPtr, dst, rank, ind);
      rewriter.create<scf::YieldOp>(loc);
      rewriter.setInsertionPointAfter(whileOp);
      rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, resType, dst);
      return success();
    }
    if (!encDst && !encSrc) {
      // dense => dense
      return failure();
    }
    // This is a dense => sparse conversion or a sparse constant in COO =>
    // sparse conversion, which is handled as follows:
    //   t = newSparseCOO()
    //   ...code to fill the COO tensor t...
    //   s = newSparseTensor(t)
    //
    // To fill the COO tensor from a dense tensor:
    //   for i1 in dim1
    //    ..
    //     for ik in dimk
    //       val = a[i1,..,ik]
    //       if val != 0
    //         t->add(val, [i1,..,ik], [p1,..,pk])
    //
    // To fill the COO tensor from a sparse constant in COO format:
    //   for i in range(NNZ)
    //     val = values[i]
    //     [i1,..,ik] = indices[i]
    //     t->add(val, [i1,..,ik], [p1,..,pk])
    //
    // Note that the dense tensor traversal code is actually implemented
    // using MLIR IR to avoid having to expose too much low-level
    // memref traversal detail to the runtime support library.
    // Also note that the code below only generates the "new" ops and
    // the loop-nest per se; whereas the entire body of the innermost
    // loop is generated by genAddEltCall().
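    // For f64 elements, for example, the innermost loop body generated by
    // genAddEltCall() contains a call of roughly the form
    //   %p = call @addEltF64(%ptr, %val, %ind, %perm)
    // where %ind is the index buffer filled above and %perm the permutation
    // buffer (params[2]).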
    ShapedType stp = resType.cast<ShapedType>();
    unsigned rank = stp.getRank();
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 8> params;
    sizesFromSrc(rewriter, sizes, loc, src);
    newParams(rewriter, params, op, stp, encDst, Action::kEmptyCOO, sizes);
    Value ptr = genNewCall(rewriter, op, params);
    Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
    Value perm = params[2];
    SmallVector<Value> lo;
    SmallVector<Value> hi;
    SmallVector<Value> st;
    Value zero = constantIndex(rewriter, loc, 0);
    Value one = constantIndex(rewriter, loc, 1);
    auto indicesValues = genSplitSparseConstant(rewriter, loc, src);
    bool isCOOConstant = indicesValues.hasValue();
    Value indices;
    Value values;
    if (isCOOConstant) {
      indices = indicesValues->first;
      values = indicesValues->second;
      lo.push_back(zero);
      hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, values, 0));
      st.push_back(one);
    } else {
      for (unsigned i = 0; i < rank; i++) {
        lo.push_back(zero);
        hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, src, i));
        st.push_back(one);
      }
    }
    Type eltType = stp.getElementType();
    scf::buildLoopNest(
        rewriter, op.getLoc(), lo, hi, st, {},
        [&](OpBuilder &builder, Location loc, ValueRange ivs,
            ValueRange args) -> scf::ValueVector {
          Value val;
          if (isCOOConstant)
            val = genIndexAndValueForSparse(rewriter, loc, indices, values, ind,
                                            ivs, rank);
          else
            val = genIndexAndValueForDense(rewriter, loc, src, ind, ivs);
          genAddEltCall(rewriter, op, eltType, ptr, val, ind, perm);
          return {};
        });
    // Final call to construct sparse tensor storage.
    params[6] = constantAction(rewriter, loc, Action::kFromCOO);
    params[7] = ptr;
    rewriter.replaceOp(op, genNewCall(rewriter, op, params));
    return success();
  }
};

/// Sparse conversion rule for the release operator.
class SparseTensorReleaseConverter : public OpConversionPattern<ReleaseOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReleaseOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    StringRef name = "delSparseTensor";
    TypeRange noTp;
    createFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
                   EmitCInterface::Off);
    rewriter.eraseOp(op);
    return success();
  }
};
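
// For example, sparse_tensor.release %t simply becomes a library call
//   call @delSparseTensor(%t) : (!llvm.ptr<i8>) -> ()
// followed by erasure of the original operation.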

/// Sparse conversion rule for pointer accesses.
class SparseTensorToPointersConverter
    : public OpConversionPattern<ToPointersOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToPointersOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type ptrType = resType.cast<ShapedType>().getElementType();
    SmallString<16> name{"sparsePointers", overheadTypeFunctionSuffix(ptrType)};
    replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
                          EmitCInterface::On);
    return success();
  }
};

/// Sparse conversion rule for index accesses.
class SparseTensorToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToIndicesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type indType = resType.cast<ShapedType>().getElementType();
    SmallString<15> name{"sparseIndices", overheadTypeFunctionSuffix(indType)};
    replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
                          EmitCInterface::On);
    return success();
  }
};

/// Sparse conversion rule for value accesses.
class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type eltType = resType.cast<ShapedType>().getElementType();
    SmallString<15> name{"sparseValues", primaryTypeFunctionSuffix(eltType)};
    replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
                          EmitCInterface::On);
    return success();
  }
};

/// Sparse conversion rule for tensor rematerialization.
class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.hasInserts()) {
      // Finalize any pending insertions.
      StringRef name = "endInsert";
      TypeRange noTp;
      createFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
                     EmitCInterface::Off);
    }
    rewriter.replaceOp(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for inserting in lexicographic index order.
class SparseTensorLexInsertConverter : public OpConversionPattern<LexInsertOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(LexInsertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type elemTp = op.tensor().getType().cast<ShapedType>().getElementType();
    SmallString<12> name{"lexInsert", primaryTypeFunctionSuffix(elemTp)};
    TypeRange noTp;
    replaceOpWithFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
                          EmitCInterface::On);
    return success();
  }
};
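
// For example, for f64 elements the rule above replaces the insertion with
// a call of roughly the form
//   call @lexInsertF64(%tensor, %indices, %value)
// on the converted operands of the original sparse_tensor.lex_insert.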

class SparseTensorExpandConverter : public OpConversionPattern<ExpandOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ExpandOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    ShapedType srcType = op.tensor().getType().cast<ShapedType>();
    Type eltType = srcType.getElementType();
    Type boolType = rewriter.getIntegerType(1);
    Type idxType = rewriter.getIndexType();
    // All initialization should be done on entry of the loop nest.
    rewriter.setInsertionPointAfter(op.tensor().getDefiningOp());
    // Determine the size for access expansion.
    auto enc = getSparseTensorEncoding(srcType);
    Value src = adaptor.getOperands()[0];
    Value sz = genDimSizeCall(rewriter, op, enc, src, srcType.getRank() - 1);
    // Allocate temporary stack buffers for values, filled-switch, and indices.
    Value values = genAlloca(rewriter, loc, sz, eltType);
    Value filled = genAlloca(rewriter, loc, sz, boolType);
    Value indices = genAlloca(rewriter, loc, sz, idxType);
    Value zero = constantZero(rewriter, loc, idxType);
    // Reset the values/filled-switch to all-zero/false. Note that this
    // introduces an O(N) operation into the computation, but this reset
    // operation is amortized over the innermost loops for the access
    // pattern expansion.
    rewriter.create<linalg::FillOp>(loc, constantZero(rewriter, loc, eltType),
                                    values);
    rewriter.create<linalg::FillOp>(loc, constantZero(rewriter, loc, boolType),
                                    filled);
    // Replace expansion op with these buffers and initial index.
    assert(op.getNumResults() == 4);
    rewriter.replaceOp(op, {values, filled, indices, zero});
    return success();
  }
};
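
// For example, for an f64 matrix the four results of sparse_tensor.expand
// materialize roughly as
//   %values  = memref.alloca(%sz) : memref<?xf64>
//   %filled  = memref.alloca(%sz) : memref<?xi1>
//   %indices = memref.alloca(%sz) : memref<?xindex>
//   %count   = arith.constant 0 : index
// with %sz the size of the innermost dimension of the source tensor.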

class SparseTensorCompressConverter : public OpConversionPattern<CompressOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(CompressOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Note that this method call resets the values/filled-switch back to
    // all-zero/false by only iterating over the set elements, so the
    // complexity remains proportional to the sparsity of the expanded
    // access pattern.
    Type elemTp = op.tensor().getType().cast<ShapedType>().getElementType();
    SmallString<12> name{"expInsert", primaryTypeFunctionSuffix(elemTp)};
    TypeRange noTp;
    replaceOpWithFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
                          EmitCInterface::On);
    return success();
  }
};

class SparseTensorOutConverter : public OpConversionPattern<OutOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(OutOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    ShapedType srcType = op.tensor().getType().cast<ShapedType>();
    // Convert to default permuted COO.
    Value src = adaptor.getOperands()[0];
    auto encSrc = getSparseTensorEncoding(srcType);
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 8> params;
    sizesFromPtr(rewriter, sizes, op, encSrc, srcType, src);
    auto enc = SparseTensorEncodingAttr::get(
        op->getContext(), encSrc.getDimLevelType(), AffineMap(),
        encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
    newParams(rewriter, params, op, srcType, enc, Action::kToCOO, sizes, src);
    Value coo = genNewCall(rewriter, op, params);
    // Then output the tensor to external file with indices in the externally
    // visible lexicographic index order. A sort is required if the source was
    // not in that order yet (note that the sort can be dropped altogether if
    // the external format does not care about the order at all, but here we
    // assume it does).
    bool sort =
        encSrc.getDimOrdering() && !encSrc.getDimOrdering().isIdentity();
    params.clear();
    params.push_back(coo);
    params.push_back(adaptor.getOperands()[1]);
    params.push_back(constantI1(rewriter, loc, sort));
    Type eltType = srcType.getElementType();
    SmallString<18> name{"outSparseTensor", primaryTypeFunctionSuffix(eltType)};
    TypeRange noTp;
    replaceOpWithFuncCall(rewriter, op, name, noTp, params,
                          EmitCInterface::Off);
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with conversion rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
                                                  RewritePatternSet &patterns) {
  patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter,
               SparseCastConverter, SparseTensorNewConverter,
               SparseTensorInitConverter, SparseTensorConvertConverter,
               SparseTensorReleaseConverter, SparseTensorToPointersConverter,
               SparseTensorToIndicesConverter, SparseTensorToValuesConverter,
               SparseTensorLoadConverter, SparseTensorLexInsertConverter,
               SparseTensorExpandConverter, SparseTensorCompressConverter,
               SparseTensorOutConverter>(typeConverter, patterns.getContext());
}
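
// A minimal usage sketch (hypothetical driver; the actual conversion pass
// lives elsewhere in this dialect): a pass would pair these rules with a
// TypeConverter that maps annotated tensor types to the opaque pointer
// type, and then run a partial conversion, e.g.
//
//   RewritePatternSet patterns(ctx);
//   populateSparseTensorConversionPatterns(typeConverter, patterns);
//   if (failed(applyPartialConversion(getOperation(), target,
//                                     std::move(patterns))))
//     signalPassFailure();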