//===- SparseTensorConversion.cpp - Sparse tensor primitives conversion ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Convert sparse tensor primitives to calls into a runtime support library.
// Note that this is a current implementation choice to keep the conversion
// simple. In principle, these primitives could also be converted to actual
// elaborate IR code that implements the primitives on the selected sparse
// tensor storage schemes.
//
//===----------------------------------------------------------------------===//

#include "CodegenUtils.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/ExecutionEngine/SparseTensorUtils.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

/// Shorthand aliases for the `emitCInterface` argument to `getFunc()`,
/// `createFuncCall()`, and `replaceOpWithFuncCall()`.
enum class EmitCInterface : bool { Off = false, On = true };

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

/// Returns the equivalent of `void*` for opaque arguments to the
/// execution engine.
static Type getOpaquePointerType(PatternRewriter &rewriter) {
  return LLVM::LLVMPointerType::get(rewriter.getI8Type());
}

/// Returns a function reference (first hit also inserts into module). Sets
/// the "llvm.emit_c_interface" attribute on the function declaration when
/// requested, so that LLVM lowering generates a wrapper function that takes
/// care of ABI complications with passing in and returning MemRefs to C
/// functions.
static FlatSymbolRefAttr getFunc(Operation *op, StringRef name,
                                 TypeRange resultType, ValueRange operands,
                                 EmitCInterface emitCInterface) {
  MLIRContext *context = op->getContext();
  auto module = op->getParentOfType<ModuleOp>();
  auto result = SymbolRefAttr::get(context, name);
  auto func = module.lookupSymbol<FuncOp>(result.getAttr());
  if (!func) {
    OpBuilder moduleBuilder(module.getBodyRegion());
    func = moduleBuilder.create<FuncOp>(
        op->getLoc(), name,
        FunctionType::get(context, operands.getTypes(), resultType));
    func.setPrivate();
    if (static_cast<bool>(emitCInterface))
      func->setAttr("llvm.emit_c_interface", UnitAttr::get(context));
  }
  return result;
}
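
// For example (illustrative; the exact signature depends on the call site),
// the first `getFunc()` hit for the runtime routine `newSparseTensor` with
// `EmitCInterface::On` materializes a declaration such as:
//
//   func private @newSparseTensor(...) -> !llvm.ptr<i8>
//     attributes { llvm.emit_c_interface }
//
// while subsequent hits simply return a reference to that declaration.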

/// Creates a `CallOp` to the function reference returned by `getFunc()`.
static CallOp createFuncCall(OpBuilder &builder, Operation *op, StringRef name,
                             TypeRange resultType, ValueRange operands,
                             EmitCInterface emitCInterface) {
  auto fn = getFunc(op, name, resultType, operands, emitCInterface);
  return builder.create<CallOp>(op->getLoc(), resultType, fn, operands);
}

/// Replaces the `op` with a `CallOp` to the function reference returned
/// by `getFunc()`.
static CallOp replaceOpWithFuncCall(PatternRewriter &rewriter, Operation *op,
                                    StringRef name, TypeRange resultType,
                                    ValueRange operands,
                                    EmitCInterface emitCInterface) {
  auto fn = getFunc(op, name, resultType, operands, emitCInterface);
  return rewriter.replaceOpWithNewOp<CallOp>(op, resultType, fn, operands);
}

/// Generates a dimension size call.
static Value genDimSizeCall(ConversionPatternRewriter &rewriter, Operation *op,
                            SparseTensorEncodingAttr &enc, Value src,
                            int64_t idx) {
  // Permute the index according to an optional dimension ordering.
  if (AffineMap p = enc.getDimOrdering())
    idx = p.getPermutedPosition(idx);
  // Generate the call.
  StringRef name = "sparseDimSize";
  SmallVector<Value, 2> params{src, constantIndex(rewriter, op->getLoc(), idx)};
  Type iTp = rewriter.getIndexType();
  return createFuncCall(rewriter, op, name, iTp, params, EmitCInterface::Off)
      .getResult(0);
}

/// Generates a call into the "swiss army knife" method of the sparse runtime
/// support library for materializing sparse tensors into the computation.
static Value genNewCall(ConversionPatternRewriter &rewriter, Operation *op,
                        ArrayRef<Value> params) {
  StringRef name = "newSparseTensor";
  Type pTp = getOpaquePointerType(rewriter);
  return createFuncCall(rewriter, op, name, pTp, params, EmitCInterface::On)
      .getResult(0);
}

/// Populates the given sizes array from the type.
static void sizesFromType(ConversionPatternRewriter &rewriter,
                          SmallVector<Value, 4> &sizes, Location loc,
                          ShapedType stp) {
  auto shape = stp.getShape();
  for (unsigned i = 0, rank = stp.getRank(); i < rank; i++) {
    uint64_t s = shape[i] == ShapedType::kDynamicSize ? 0 : shape[i];
    sizes.push_back(constantIndex(rewriter, loc, s));
  }
}

/// Populates the given sizes array from the source.
static void sizesFromSrc(ConversionPatternRewriter &rewriter,
                         SmallVector<Value, 4> &sizes, Location loc,
                         Value src) {
  unsigned rank = src.getType().cast<ShapedType>().getRank();
  for (unsigned i = 0; i < rank; i++)
    sizes.push_back(linalg::createOrFoldDimOp(rewriter, loc, src, i));
}

/// Populates the given sizes array from the type (for static sizes) and
/// from the source (for dynamic sizes), where the source has already been
/// converted to an opaque pointer.
static void sizesFromPtr(ConversionPatternRewriter &rewriter,
                         SmallVector<Value, 4> &sizes, Operation *op,
                         SparseTensorEncodingAttr &enc, ShapedType stp,
                         Value src) {
  Location loc = op->getLoc();
  auto shape = stp.getShape();
  for (unsigned i = 0, rank = stp.getRank(); i < rank; i++)
    if (shape[i] == ShapedType::kDynamicSize)
      sizes.push_back(genDimSizeCall(rewriter, op, enc, src, i));
    else
      sizes.push_back(constantIndex(rewriter, loc, shape[i]));
}
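
// For example (illustrative, absent a dimension ordering), for a source of
// type `tensor<?x4xf64>` the populated sizes array would hold:
//
//   sizes[0] = call @sparseDimSize(%src, %c0) : (!llvm.ptr<i8>, index) -> index
//   sizes[1] = arith.constant 4 : index
//
// i.e. the dynamic dimension is queried from the runtime while the static
// dimension is materialized as a constant.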

/// Generates an uninitialized temporary buffer of the given size and
/// type, but returns it as type `memref<? x $tp>` (rather than as type
/// `memref<$sz x $tp>`).
static Value genAlloca(ConversionPatternRewriter &rewriter, Location loc,
                       Value sz, Type tp) {
  auto memTp = MemRefType::get({ShapedType::kDynamicSize}, tp);
  return rewriter.create<memref::AllocaOp>(loc, memTp, ValueRange{sz});
}

/// Generates an uninitialized temporary buffer of the given size and
/// type, but returns it as type `memref<? x $tp>` (rather than as type
/// `memref<$sz x $tp>`).
static Value genAlloca(ConversionPatternRewriter &rewriter, Location loc,
                       unsigned sz, Type tp) {
  return genAlloca(rewriter, loc, constantIndex(rewriter, loc, sz), tp);
}

/// Generates an uninitialized temporary buffer with room for one value
/// of the given type, and returns the `memref<$tp>`.
static Value genAllocaScalar(ConversionPatternRewriter &rewriter, Location loc,
                             Type tp) {
  return rewriter.create<memref::AllocaOp>(loc, MemRefType::get({}, tp));
}

/// Generates a temporary buffer of the given type and given contents.
static Value genBuffer(ConversionPatternRewriter &rewriter, Location loc,
                       ArrayRef<Value> values) {
  unsigned sz = values.size();
  assert(sz >= 1);
  Value buffer = genAlloca(rewriter, loc, sz, values[0].getType());
  for (unsigned i = 0; i < sz; i++) {
    Value idx = constantIndex(rewriter, loc, i);
    rewriter.create<memref::StoreOp>(loc, values[i], buffer, idx);
  }
  return buffer;
}

/// Populates parameters required to call the "swiss army knife" method of the
/// sparse runtime support library for materializing sparse tensors into the
/// computation.
static void newParams(ConversionPatternRewriter &rewriter,
                      SmallVector<Value, 8> &params, Operation *op,
                      SparseTensorEncodingAttr &enc, Action action,
                      ValueRange szs, Value ptr = Value()) {
  Location loc = op->getLoc();
  ArrayRef<SparseTensorEncodingAttr::DimLevelType> dlt = enc.getDimLevelType();
  unsigned sz = dlt.size();
  // Sparsity annotations.
  SmallVector<Value, 4> attrs;
  for (unsigned i = 0; i < sz; i++)
    attrs.push_back(constantDimLevelTypeEncoding(rewriter, loc, dlt[i]));
  params.push_back(genBuffer(rewriter, loc, attrs));
  // Dimension sizes array of the enveloping tensor. Useful for either
  // verification of external data, or for construction of internal data.
  SmallVector<Value, 4> sizes;
  for (Value s : szs)
    sizes.push_back(s);
  params.push_back(genBuffer(rewriter, loc, sizes));
  // Dimension order permutation array. This is the "identity" permutation by
  // default, or otherwise the "reverse" permutation of a given ordering, so
  // that indices can be mapped quickly to the right position.
  SmallVector<Value, 4> rev(sz);
  if (AffineMap p = enc.getDimOrdering()) {
    for (unsigned i = 0; i < sz; i++)
      rev[p.getDimPosition(i)] = constantIndex(rewriter, loc, i);
  } else {
    for (unsigned i = 0; i < sz; i++)
      rev[i] = constantIndex(rewriter, loc, i);
  }
  params.push_back(genBuffer(rewriter, loc, rev));
  // Secondary and primary types encoding.
  Type elemTp = op->getResult(0).getType().cast<ShapedType>().getElementType();
  params.push_back(constantPointerTypeEncoding(rewriter, loc, enc));
  params.push_back(constantIndexTypeEncoding(rewriter, loc, enc));
  params.push_back(constantPrimaryTypeEncoding(rewriter, loc, elemTp));
  // User action.
  params.push_back(constantAction(rewriter, loc, action));
  // Payload pointer.
  if (!ptr)
    ptr = rewriter.create<LLVM::NullOp>(loc, getOpaquePointerType(rewriter));
  params.push_back(ptr);
}
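
// For reference (derived from the push order above), on return `params`
// holds, in order, the entries that the conversion rules below index into:
//
//   params[0]  dimension level types buffer
//   params[1]  dimension sizes buffer
//   params[2]  dimension order permutation buffer
//   params[3]  pointer type encoding
//   params[4]  index type encoding
//   params[5]  primary (element) type encoding
//   params[6]  user action
//   params[7]  payload pointer (null by default)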

/// Generates the code to read the value from tensor[ivs], and conditionally
/// stores the indices ivs to the memory in ind. The generated code looks
/// like the following, and the insertion point after this routine is inside
/// the if-then branch, just after the assignment to ind. This ensures that
/// the addEltX call generated subsequently is placed inside the if-then
/// branch as well.
///    if (tensor[ivs] != 0) {
///      ind = ivs
static Value genIndexAndValueForDense(ConversionPatternRewriter &rewriter,
                                      Location loc, Value tensor, Value ind,
                                      ValueRange ivs) {
  Value val = rewriter.create<tensor::ExtractOp>(loc, tensor, ivs);
  Value cond = genIsNonzero(rewriter, loc, val);
  scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ false);
  rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
  unsigned i = 0;
  for (auto iv : ivs) {
    Value idx = constantIndex(rewriter, loc, i++);
    rewriter.create<memref::StoreOp>(loc, iv, ind, idx);
  }
  return val;
}

/// Generates a call that adds one element to a coordinate scheme.
/// In particular, this generates code like the following:
///   val = a[i1,..,ik];
///   if val != 0
///     t->add(val, [i1,..,ik], [p1,..,pk]);
static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
                          Type eltType, Value ptr, Value val, Value ind,
                          Value perm) {
  SmallString<9> name{"addElt", primaryTypeFunctionSuffix(eltType)};
  SmallVector<Value, 4> params{ptr, val, ind, perm};
  Type pTp = getOpaquePointerType(rewriter);
  createFuncCall(rewriter, op, name, pTp, params, EmitCInterface::On);
}

/// Generates a call to `iter->getNext()`. If there is a next element,
/// then it is copied into the out-parameters `ind` and `elemPtr`,
/// and the return value is true. If there isn't a next element, then
/// the memory for `iter` is freed and the return value is false.
static Value genGetNextCall(ConversionPatternRewriter &rewriter, Operation *op,
                            Value iter, Value ind, Value elemPtr) {
  Type elemTp = elemPtr.getType().cast<ShapedType>().getElementType();
  SmallString<10> name{"getNext", primaryTypeFunctionSuffix(elemTp)};
  SmallVector<Value, 3> params{iter, ind, elemPtr};
  Type i1 = rewriter.getI1Type();
  return createFuncCall(rewriter, op, name, i1, params, EmitCInterface::On)
      .getResult(0);
}

/// If the tensor is a sparse constant, generates and returns the pair of
/// the constants for the indices and the values.
static Optional<std::pair<Value, Value>>
genSplitSparseConstant(ConversionPatternRewriter &rewriter, Location loc,
                       Value tensor) {
  if (auto constOp = tensor.getDefiningOp<arith::ConstantOp>()) {
    if (auto attr = constOp.getValue().dyn_cast<SparseElementsAttr>()) {
      DenseElementsAttr indicesAttr = attr.getIndices();
      Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr);
      DenseElementsAttr valuesAttr = attr.getValues();
      Value values = rewriter.create<arith::ConstantOp>(loc, valuesAttr);
      return std::make_pair(indices, values);
    }
  }
  return {};
}
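
// For example (illustrative), a sparse constant such as
//
//   arith.constant sparse<[[0, 0], [1, 2]], [1.0, 5.0]> : tensor<2x3xf64>
//
// splits into a constant holding the nonzero coordinates and a constant
// holding the nonzero values:
//
//   arith.constant dense<[[0, 0], [1, 2]]> : tensor<2x2xi64>
//   arith.constant dense<[1.0, 5.0]> : tensor<2xf64>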

/// Generates the code to copy the index at indices[ivs] to ind, and return
/// the value at values[ivs].
static Value genIndexAndValueForSparse(ConversionPatternRewriter &rewriter,
                                       Location loc, Value indices,
                                       Value values, Value ind, ValueRange ivs,
                                       unsigned rank) {
  for (unsigned i = 0; i < rank; i++) {
    Value idx = constantIndex(rewriter, loc, i);
    Value val = rewriter.create<tensor::ExtractOp>(loc, indices,
                                                   ValueRange{ivs[0], idx});
    val =
        rewriter.create<arith::IndexCastOp>(loc, val, rewriter.getIndexType());
    rewriter.create<memref::StoreOp>(loc, val, ind, idx);
  }
  return rewriter.create<tensor::ExtractOp>(loc, values, ivs[0]);
}

/// Generates code to allocate a tensor of the given type, and zero
/// initialize it. If the tensor type has any dynamic sizes, then the
/// `sizes` parameter should be as filled by sizesFromPtr(); that way
/// we can reuse the genDimSizeCall() results generated by sizesFromPtr().
static Value allocDenseTensor(ConversionPatternRewriter &rewriter,
                              Location loc, RankedTensorType tensorTp,
                              ValueRange sizes) {
  Type elemTp = tensorTp.getElementType();
  auto shape = tensorTp.getShape();
  auto memTp = MemRefType::get(shape, elemTp);
  SmallVector<Value> dynamicSizes;
  for (unsigned i = 0, rank = tensorTp.getRank(); i < rank; i++) {
    if (shape[i] == ShapedType::kDynamicSize)
      dynamicSizes.push_back(sizes[i]);
  }
  Value mem = rewriter.create<memref::AllocOp>(loc, memTp, dynamicSizes);
  Value zero = constantZero(rewriter, loc, elemTp);
  rewriter.create<linalg::FillOp>(loc, zero, mem);
  return mem;
}

/// Inserts the element returned by genGetNextCall(_, ind, elemPtr) into
/// the tensor created by allocDenseTensor(). The `rank` is the rank
/// of the `tensor` and the length of `ind`.
static void insertScalarIntoDenseTensor(ConversionPatternRewriter &rewriter,
                                        Location loc, Value elemPtr,
                                        Value tensor, unsigned rank,
                                        Value ind) {
  SmallVector<Value, 4> ivs;
  ivs.reserve(rank);
  for (unsigned i = 0; i < rank; i++) {
    Value idx = constantIndex(rewriter, loc, i);
    ivs.push_back(rewriter.create<memref::LoadOp>(loc, ind, idx));
  }
  Value elemV = rewriter.create<memref::LoadOp>(loc, elemPtr);
  rewriter.create<memref::StoreOp>(loc, elemV, tensor, ivs);
}
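
// For example (illustrative, for a rank-2 tensor of f64 elements), the code
// generated by the routine above looks like:
//
//   %i0 = memref.load %ind[%c0] : memref<?xindex>
//   %i1 = memref.load %ind[%c1] : memref<?xindex>
//   %e  = memref.load %elemPtr[] : memref<f64>
//   memref.store %e, %tensor[%i0, %i1] : memref<?x?xf64>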

//===----------------------------------------------------------------------===//
// Conversion rules.
//===----------------------------------------------------------------------===//

/// Sparse conversion rule for returns.
class SparseReturnConverter : public OpConversionPattern<ReturnOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<ReturnOp>(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for dimension accesses.
class SparseTensorToDimSizeConverter
    : public OpConversionPattern<tensor::DimOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only rewrite annotated DimOp with constant index.
    auto enc = getSparseTensorEncoding(op.source().getType());
    if (!enc)
      return failure();
    Optional<int64_t> index = op.getConstantIndex();
    if (!index.hasValue())
      return failure();
    // Generate the call.
    Value src = adaptor.getOperands()[0];
    int64_t idx = index.getValue();
    rewriter.replaceOp(op, genDimSizeCall(rewriter, op, enc, src, idx));
    return success();
  }
};

/// Sparse conversion rule for trivial tensor casts.
class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only rewrite identically annotated source/dest.
    auto encDst = getSparseTensorEncoding(op.getType());
    auto encSrc = getSparseTensorEncoding(op.source().getType());
    if (!encDst || encDst != encSrc)
      return failure();
    rewriter.replaceOp(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for the new operator.
class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NewOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    auto enc = getSparseTensorEncoding(resType);
    if (!enc)
      return failure();
    // Generate the call to construct a tensor from the ptr. The sizes are
    // inferred from the result type of the new operator.
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 8> params;
    sizesFromType(rewriter, sizes, op.getLoc(), resType.cast<ShapedType>());
    Value ptr = adaptor.getOperands()[0];
    newParams(rewriter, params, op, enc, Action::kFromFile, sizes, ptr);
    rewriter.replaceOp(op, genNewCall(rewriter, op, params));
    return success();
  }
};

/// Sparse conversion rule for the init operator.
class SparseTensorInitConverter : public OpConversionPattern<InitOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(InitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    auto enc = getSparseTensorEncoding(resType);
    if (!enc)
      return failure();
    // Generate the call to construct an empty tensor. The sizes are
    // explicitly defined by the arguments to the init operator.
    SmallVector<Value, 8> params;
    newParams(rewriter, params, op, enc, Action::kEmpty, adaptor.getOperands());
    rewriter.replaceOp(op, genNewCall(rewriter, op, params));
    return success();
  }
};

/// Sparse conversion rule for the convert operator.
class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ConvertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    Type resType = op.getType();
    Type srcType = op.source().getType();
    auto encDst = getSparseTensorEncoding(resType);
    auto encSrc = getSparseTensorEncoding(srcType);
    Value src = adaptor.getOperands()[0];
    if (encDst && encSrc) {
      // This is a sparse => sparse conversion, which is handled as follows:
      //   t = src->toCOO();         ; src to COO in dst order
      //   dst = newSparseTensor(t)
      // Using the coordinate scheme as an intermediate does not always
      // yield the fastest conversion but avoids the need for a full
      // O(N^2) conversion matrix.
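      // An illustrative sketch of the generated call sequence (parameters
      // abridged; see newParams() for the full layout):
      //
      //   %coo = call @newSparseTensor(..., kToCOO, %src)
      //   %dst = call @newSparseTensor(..., kFromCOO, %coo)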
      if (encDst == encSrc) {
        rewriter.replaceOp(op, adaptor.getOperands()); // hidden nop cast
        return success();
      }
      SmallVector<Value, 4> sizes;
      SmallVector<Value, 8> params;
      sizesFromPtr(rewriter, sizes, op, encSrc, srcType.cast<ShapedType>(),
                   src);
      // Set up encoding with right mix of src and dst so that the two
      // method calls can share most parameters, while still providing
      // the correct sparsity information to either of them.
      auto enc = SparseTensorEncodingAttr::get(
          op->getContext(), encDst.getDimLevelType(), encDst.getDimOrdering(),
          encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
      newParams(rewriter, params, op, enc, Action::kToCOO, sizes, src);
      Value coo = genNewCall(rewriter, op, params);
      params[3] = constantPointerTypeEncoding(rewriter, loc, encDst);
      params[4] = constantIndexTypeEncoding(rewriter, loc, encDst);
      params[6] = constantAction(rewriter, loc, Action::kFromCOO);
      params[7] = coo;
      rewriter.replaceOp(op, genNewCall(rewriter, op, params));
      return success();
    }
    if (!encDst && encSrc) {
      // This is sparse => dense conversion, which is handled as follows:
      //   dst = new Tensor(0);
      //   iter = src->toCOO();
      //   iter->startIterator();
      //   while (elem = iter->getNext()) {
      //     dst[elem.indices] = elem.value;
      //   }
      RankedTensorType dstTensorTp = resType.cast<RankedTensorType>();
      RankedTensorType srcTensorTp = srcType.cast<RankedTensorType>();
      unsigned rank = dstTensorTp.getRank();
      Type elemTp = dstTensorTp.getElementType();
      // Fabricate a no-permutation encoding for newParams().
      // The pointer/index types must be those of `src`.
      // The dimLevelTypes aren't actually used by Action::kToIterator.
      encDst = SparseTensorEncodingAttr::get(
          op->getContext(),
          SmallVector<SparseTensorEncodingAttr::DimLevelType>(
              rank, SparseTensorEncodingAttr::DimLevelType::Dense),
          AffineMap(), encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
      SmallVector<Value, 4> sizes;
      SmallVector<Value, 8> params;
      sizesFromPtr(rewriter, sizes, op, encSrc, srcTensorTp, src);
      newParams(rewriter, params, op, encDst, Action::kToIterator, sizes, src);
      Value iter = genNewCall(rewriter, op, params);
      Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
      Value elemPtr = genAllocaScalar(rewriter, loc, elemTp);
      Value dst = allocDenseTensor(rewriter, loc, dstTensorTp, sizes);
      SmallVector<Value> noArgs;
      SmallVector<Type> noTypes;
      auto whileOp = rewriter.create<scf::WhileOp>(loc, noTypes, noArgs);
      Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, noTypes);
      rewriter.setInsertionPointToEnd(before);
      Value cond = genGetNextCall(rewriter, op, iter, ind, elemPtr);
      rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments());
      Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, noTypes);
      rewriter.setInsertionPointToStart(after);
      insertScalarIntoDenseTensor(rewriter, loc, elemPtr, dst, rank, ind);
      rewriter.create<scf::YieldOp>(loc);
      rewriter.setInsertionPointAfter(whileOp);
      rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, resType, dst);
      return success();
    }
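    // The branch above emits a loop along these lines (illustrative, for
    // f64 elements; the runtime function name is derived from the element
    // type by primaryTypeFunctionSuffix()):
    //
    //   scf.while : () -> () {
    //     %cond = call @getNextF64(%iter, %ind, %elemPtr)
    //     scf.condition(%cond)
    //   } do {
    //     ... insertScalarIntoDenseTensor(%elemPtr, %dst, %ind) ...
    //     scf.yield
    //   }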
    if (!encDst && !encSrc) {
      // dense => dense
      return failure();
    }
    // This is a dense => sparse conversion or a sparse constant in COO =>
    // sparse conversion, which is handled as follows:
    //   t = newSparseCOO()
    //   ...code to fill the COO tensor t...
    //   s = newSparseTensor(t)
    //
    // To fill the COO tensor from a dense tensor:
    //   for i1 in dim1
    //    ..
    //     for ik in dimk
    //       val = a[i1,..,ik]
    //       if val != 0
    //         t->add(val, [i1,..,ik], [p1,..,pk])
    //
    // To fill the COO tensor from a sparse constant in COO format:
    //   for i in range(NNZ)
    //     val = values[i]
    //     [i1,..,ik] = indices[i]
    //     t->add(val, [i1,..,ik], [p1,..,pk])
    //
    // Note that the dense tensor traversal code is actually implemented
    // using MLIR IR to avoid having to expose too much low-level
    // memref traversal details to the runtime support library.
    // Also note that the code below only generates the "new" ops and
    // the loop-nest per se; the entire body of the innermost loop is
    // generated by genAddEltCall().
    ShapedType stp = resType.cast<ShapedType>();
    unsigned rank = stp.getRank();
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 8> params;
    sizesFromSrc(rewriter, sizes, loc, src);
    newParams(rewriter, params, op, encDst, Action::kEmptyCOO, sizes);
    Value ptr = genNewCall(rewriter, op, params);
    Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
    Value perm = params[2];
    SmallVector<Value> lo;
    SmallVector<Value> hi;
    SmallVector<Value> st;
    Value zero = constantIndex(rewriter, loc, 0);
    Value one = constantIndex(rewriter, loc, 1);
    auto indicesValues = genSplitSparseConstant(rewriter, loc, src);
    bool isCOOConstant = indicesValues.hasValue();
    Value indices;
    Value values;
    if (isCOOConstant) {
      indices = indicesValues->first;
      values = indicesValues->second;
      lo.push_back(zero);
      hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, values, 0));
      st.push_back(one);
    } else {
      for (unsigned i = 0; i < rank; i++) {
        lo.push_back(zero);
        hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, src, i));
        st.push_back(one);
      }
    }
    Type eltType = stp.getElementType();
    scf::buildLoopNest(
        rewriter, op.getLoc(), lo, hi, st, {},
        [&](OpBuilder &builder, Location loc, ValueRange ivs,
            ValueRange args) -> scf::ValueVector {
          Value val;
          if (isCOOConstant)
            val = genIndexAndValueForSparse(rewriter, loc, indices, values,
                                            ind, ivs, rank);
          else
            val = genIndexAndValueForDense(rewriter, loc, src, ind, ivs);
          genAddEltCall(rewriter, op, eltType, ptr, val, ind, perm);
          return {};
        });
    // Final call to construct sparse tensor storage.
    params[6] = constantAction(rewriter, loc, Action::kFromCOO);
    params[7] = ptr;
    rewriter.replaceOp(op, genNewCall(rewriter, op, params));
    return success();
  }
};

/// Sparse conversion rule for the release operator.
class SparseTensorReleaseConverter : public OpConversionPattern<ReleaseOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReleaseOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    StringRef name = "delSparseTensor";
    TypeRange noTp;
    createFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
                   EmitCInterface::Off);
    rewriter.eraseOp(op);
    return success();
  }
};
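
// For example (illustrative), after type conversion a release becomes a
// single call that frees the underlying storage:
//
//   call @delSparseTensor(%ptr) : (!llvm.ptr<i8>) -> ()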

/// Sparse conversion rule for pointer accesses.
class SparseTensorToPointersConverter
    : public OpConversionPattern<ToPointersOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToPointersOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type ptrType = resType.cast<ShapedType>().getElementType();
    SmallString<16> name{"sparsePointers", overheadTypeFunctionSuffix(ptrType)};
    replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
                          EmitCInterface::On);
    return success();
  }
};

/// Sparse conversion rule for index accesses.
class SparseTensorToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToIndicesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type indType = resType.cast<ShapedType>().getElementType();
    SmallString<15> name{"sparseIndices", overheadTypeFunctionSuffix(indType)};
    replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
                          EmitCInterface::On);
    return success();
  }
};

/// Sparse conversion rule for value accesses.
class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type eltType = resType.cast<ShapedType>().getElementType();
    SmallString<15> name{"sparseValues", primaryTypeFunctionSuffix(eltType)};
    replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
                          EmitCInterface::On);
    return success();
  }
};

/// Sparse conversion rule for tensor rematerialization.
class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.hasInserts()) {
      // Finalize any pending insertions.
      StringRef name = "endInsert";
      TypeRange noTp;
      createFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
                     EmitCInterface::Off);
    }
    rewriter.replaceOp(op, adaptor.getOperands());
    return success();
  }
};
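
// For example (illustrative), a load with pending insertions becomes:
//
//   call @endInsert(%ptr) : (!llvm.ptr<i8>) -> ()
//
// after which the op is simply replaced by its converted operand.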

/// Sparse conversion rule for inserting in lexicographic index order.
class SparseTensorLexInsertConverter
    : public OpConversionPattern<LexInsertOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(LexInsertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type elemTp = op.tensor().getType().cast<ShapedType>().getElementType();
    SmallString<12> name{"lexInsert", primaryTypeFunctionSuffix(elemTp)};
    TypeRange noTp;
    replaceOpWithFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
                          EmitCInterface::On);
    return success();
  }
};

/// Sparse conversion rule for the expand operator.
class SparseTensorExpandConverter : public OpConversionPattern<ExpandOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ExpandOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    ShapedType srcType = op.tensor().getType().cast<ShapedType>();
    Type eltType = srcType.getElementType();
    Type boolType = rewriter.getIntegerType(1);
    Type idxType = rewriter.getIndexType();
    // All initialization should be done on entry of the loop nest.
    rewriter.setInsertionPointAfter(op.tensor().getDefiningOp());
    // Determine the size for access expansion.
    auto enc = getSparseTensorEncoding(srcType);
    Value src = adaptor.getOperands()[0];
    Value sz = genDimSizeCall(rewriter, op, enc, src, srcType.getRank() - 1);
    // Allocate temporary stack buffers for values, filled-switch, and indices.
    Value values = genAlloca(rewriter, loc, sz, eltType);
    Value filled = genAlloca(rewriter, loc, sz, boolType);
    Value indices = genAlloca(rewriter, loc, sz, idxType);
    Value zero = constantZero(rewriter, loc, idxType);
    // Reset the values/filled-switch to all-zero/false. Note that this
    // introduces an O(N) operation into the computation, but this reset
    // operation is amortized over the innermost loops for the access
    // pattern expansion.
    rewriter.create<linalg::FillOp>(loc, constantZero(rewriter, loc, eltType),
                                    values);
    rewriter.create<linalg::FillOp>(loc, constantZero(rewriter, loc, boolType),
                                    filled);
    // Replace expansion op with these buffers and initial index.
    assert(op.getNumResults() == 4);
    rewriter.replaceOp(op, {values, filled, indices, zero});
    return success();
  }
};

/// Sparse conversion rule for the compress operator.
class SparseTensorCompressConverter : public OpConversionPattern<CompressOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(CompressOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Note that this method call resets the values/filled-switch back to
    // all-zero/false by only iterating over the set elements, so the
    // complexity remains proportional to the sparsity of the expanded
    // access pattern.
    Type elemTp = op.tensor().getType().cast<ShapedType>().getElementType();
    SmallString<12> name{"expInsert", primaryTypeFunctionSuffix(elemTp)};
    TypeRange noTp;
    replaceOpWithFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
                          EmitCInterface::On);
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with conversion rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
                                                  RewritePatternSet &patterns) {
  patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter,
               SparseCastConverter, SparseTensorNewConverter,
               SparseTensorInitConverter, SparseTensorConvertConverter,
               SparseTensorReleaseConverter, SparseTensorToPointersConverter,
               SparseTensorToIndicesConverter, SparseTensorToValuesConverter,
               SparseTensorLoadConverter, SparseTensorLexInsertConverter,
               SparseTensorExpandConverter, SparseTensorCompressConverter>(
      typeConverter, patterns.getContext());
}
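
// A minimal usage sketch (hypothetical pass boilerplate; the actual pass
// setup lives elsewhere): these patterns assume a type converter that maps
// annotated sparse tensor types to the opaque runtime pointer type, e.g.:
//
//   TypeConverter converter;
//   converter.addConversion([](Type type) -> Optional<Type> {
//     if (getSparseTensorEncoding(type))
//       return LLVM::LLVMPointerType::get(
//           IntegerType::get(type.getContext(), 8));
//     return type;
//   });
//   RewritePatternSet patterns(&getContext());
//   populateSparseTensorConversionPatterns(converter, patterns);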