//===- SparseTensorConversion.cpp - Sparse tensor primitives conversion ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Convert sparse tensor primitives to calls into a runtime support library.
// Note that this is a current implementation choice to keep the conversion
// simple. In principle, these primitives could also be converted to actual
// elaborate IR code that implements the primitives on the selected sparse
// tensor storage schemes.
//
//===----------------------------------------------------------------------===//

#include "CodegenUtils.h"

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/ExecutionEngine/SparseTensorUtils.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

/// Shorthand aliases for the `emitCInterface` argument to `getFunc()`,
/// `createFuncCall()`, and `replaceOpWithFuncCall()`.
enum class EmitCInterface : bool { Off = false, On = true };

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

/// Returns the equivalent of `void*` for opaque arguments to the
/// execution engine.
static Type getOpaquePointerType(OpBuilder &builder) {
  return LLVM::LLVMPointerType::get(builder.getI8Type());
}

/// Returns a function reference (first hit also inserts into the module).
/// Sets the "_emit_c_interface" attribute on the function declaration when
/// requested, so that LLVM lowering generates a wrapper function that takes
/// care of ABI complications with passing in and returning MemRefs to C
/// functions.
static FlatSymbolRefAttr getFunc(Operation *op, StringRef name,
                                 TypeRange resultType, ValueRange operands,
                                 EmitCInterface emitCInterface) {
  MLIRContext *context = op->getContext();
  auto module = op->getParentOfType<ModuleOp>();
  auto result = SymbolRefAttr::get(context, name);
  auto func = module.lookupSymbol<func::FuncOp>(result.getAttr());
  if (!func) {
    OpBuilder moduleBuilder(module.getBodyRegion());
    func = moduleBuilder.create<func::FuncOp>(
        op->getLoc(), name,
        FunctionType::get(context, operands.getTypes(), resultType));
    func.setPrivate();
    if (static_cast<bool>(emitCInterface))
      func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
                    UnitAttr::get(context));
  }
  return result;
}
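// For reference, `getFunc()` synthesizes a declaration along the following
// lines (a sketch; the exact signature depends on the requested result and
// operand types), which the call builders below then target:
//
//   func.func private @newSparseTensor(...) -> !llvm.ptr<i8>
//     attributes {llvm.emit_c_interface}
//
// where the `llvm.emit_c_interface` attribute is attached only when
// `EmitCInterface::On` is passed.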
/// Creates a `CallOp` to the function reference returned by `getFunc()`.
static func::CallOp createFuncCall(OpBuilder &builder, Operation *op,
                                   StringRef name, TypeRange resultType,
                                   ValueRange operands,
                                   EmitCInterface emitCInterface) {
  auto fn = getFunc(op, name, resultType, operands, emitCInterface);
  return builder.create<func::CallOp>(op->getLoc(), resultType, fn, operands);
}

/// Replaces the `op` with a `CallOp` to the function reference returned
/// by `getFunc()`.
static func::CallOp replaceOpWithFuncCall(RewriterBase &rewriter, Operation *op,
                                          StringRef name, TypeRange resultType,
                                          ValueRange operands,
                                          EmitCInterface emitCInterface) {
  auto fn = getFunc(op, name, resultType, operands, emitCInterface);
  return rewriter.replaceOpWithNewOp<func::CallOp>(op, resultType, fn,
                                                   operands);
}

/// Generates a call to obtain the size of the given dimension.
static Value genDimSizeCall(OpBuilder &builder, Operation *op,
                            SparseTensorEncodingAttr &enc, Value src,
                            int64_t idx) {
  // Permute the index according to an optional dimension ordering.
  if (AffineMap p = enc.getDimOrdering())
    idx = p.getPermutedPosition(idx);
  // Generate the call.
  StringRef name = "sparseDimSize";
  SmallVector<Value, 2> params{src, constantIndex(builder, op->getLoc(), idx)};
  Type iTp = builder.getIndexType();
  return createFuncCall(builder, op, name, iTp, params, EmitCInterface::Off)
      .getResult(0);
}

/// Generates a call into the "swiss army knife" method of the sparse runtime
/// support library for materializing sparse tensors into the computation.
static Value genNewCall(OpBuilder &builder, Operation *op,
                        ArrayRef<Value> params) {
  StringRef name = "newSparseTensor";
  Type pTp = getOpaquePointerType(builder);
  return createFuncCall(builder, op, name, pTp, params, EmitCInterface::On)
      .getResult(0);
}

/// Populates the given sizes array from the type (using zero for each
/// dynamic size).
static void sizesFromType(OpBuilder &builder, SmallVector<Value, 4> &sizes,
                          Location loc, ShapedType stp) {
  auto shape = stp.getShape();
  for (unsigned i = 0, rank = stp.getRank(); i < rank; i++) {
    uint64_t s = shape[i] == ShapedType::kDynamicSize ? 0 : shape[i];
    sizes.push_back(constantIndex(builder, loc, s));
  }
}

/// Populates the given sizes array from the source tensor.
static void sizesFromSrc(OpBuilder &builder, SmallVector<Value, 4> &sizes,
                         Location loc, Value src) {
  unsigned rank = src.getType().cast<ShapedType>().getRank();
  for (unsigned i = 0; i < rank; i++)
    sizes.push_back(linalg::createOrFoldDimOp(builder, loc, src, i));
}

/// Populates the given sizes array from the type (for static sizes) and
/// from the source, which has already been converted to an opaque pointer
/// (for dynamic sizes).
static void sizesFromPtr(OpBuilder &builder, SmallVector<Value, 4> &sizes,
                         Operation *op, SparseTensorEncodingAttr &enc,
                         ShapedType stp, Value src) {
  Location loc = op->getLoc();
  auto shape = stp.getShape();
  for (unsigned i = 0, rank = stp.getRank(); i < rank; i++)
    if (shape[i] == ShapedType::kDynamicSize)
      sizes.push_back(genDimSizeCall(builder, op, enc, src, i));
    else
      sizes.push_back(constantIndex(builder, loc, shape[i]));
}

/// Generates an uninitialized temporary buffer of the given size and
/// type, but returns it as type `memref<? x $tp>` (rather than as type
/// `memref<$sz x $tp>`).
static Value genAlloca(OpBuilder &builder, Location loc, Value sz, Type tp) {
  auto memTp = MemRefType::get({ShapedType::kDynamicSize}, tp);
  return builder.create<memref::AllocaOp>(loc, memTp, ValueRange{sz});
}

/// Generates an uninitialized buffer of the given size and type,
/// but returns it as type `memref<? x $tp>` (rather than as type
/// `memref<$sz x $tp>`). Unlike temporary buffers on the stack,
/// this buffer must be explicitly deallocated by the client.
static Value genAlloc(RewriterBase &rewriter, Location loc, Value sz, Type tp) {
  auto memTp = MemRefType::get({ShapedType::kDynamicSize}, tp);
  return rewriter.create<memref::AllocOp>(loc, memTp, ValueRange{sz});
}

/// Generates an uninitialized temporary buffer of the given size and
/// type, but returns it as type `memref<? x $tp>` (rather than as type
/// `memref<$sz x $tp>`).
static Value genAlloca(OpBuilder &builder, Location loc, unsigned sz, Type tp) {
  return genAlloca(builder, loc, constantIndex(builder, loc, sz), tp);
}

/// Generates an uninitialized temporary buffer with room for one value
/// of the given type, and returns the `memref<$tp>`.
static Value genAllocaScalar(OpBuilder &builder, Location loc, Type tp) {
  return builder.create<memref::AllocaOp>(loc, MemRefType::get({}, tp));
}

/// Generates a temporary buffer of the given type and given contents.
static Value genBuffer(OpBuilder &builder, Location loc, ValueRange values) {
  unsigned sz = values.size();
  assert(sz >= 1);
  Value buffer = genAlloca(builder, loc, sz, values[0].getType());
  for (unsigned i = 0; i < sz; i++) {
    Value idx = constantIndex(builder, loc, i);
    builder.create<memref::StoreOp>(loc, values[i], buffer, idx);
  }
  return buffer;
}

/// Populates the parameters required to call the "swiss army knife" method
/// of the sparse runtime support library for materializing sparse tensors
/// into the computation.
static void newParams(OpBuilder &builder, SmallVector<Value, 8> &params,
                      Operation *op, ShapedType stp,
                      SparseTensorEncodingAttr &enc, Action action,
                      ValueRange szs, Value ptr = Value()) {
  Location loc = op->getLoc();
  ArrayRef<SparseTensorEncodingAttr::DimLevelType> dlt = enc.getDimLevelType();
  unsigned sz = dlt.size();
  // Sparsity annotations.
  SmallVector<Value, 4> attrs;
  for (unsigned i = 0; i < sz; i++)
    attrs.push_back(constantDimLevelTypeEncoding(builder, loc, dlt[i]));
  params.push_back(genBuffer(builder, loc, attrs));
  // Dimension sizes array of the enveloping tensor. Useful for either
  // verification of external data, or for construction of internal data.
  params.push_back(genBuffer(builder, loc, szs));
  // Dimension order permutation array. This is the "identity" permutation by
  // default, or otherwise the "reverse" permutation of a given ordering, so
  // that indices can be mapped quickly to the right position.
  SmallVector<Value, 4> rev(sz);
  if (AffineMap p = enc.getDimOrdering()) {
    for (unsigned i = 0; i < sz; i++)
      rev[p.getDimPosition(i)] = constantIndex(builder, loc, i);
  } else {
    for (unsigned i = 0; i < sz; i++)
      rev[i] = constantIndex(builder, loc, i);
  }
  params.push_back(genBuffer(builder, loc, rev));
  // Secondary and primary types encoding.
  Type elemTp = stp.getElementType();
  params.push_back(constantPointerTypeEncoding(builder, loc, enc));
  params.push_back(constantIndexTypeEncoding(builder, loc, enc));
  params.push_back(constantPrimaryTypeEncoding(builder, loc, elemTp));
  // User action.
  params.push_back(constantAction(builder, loc, action));
  // Payload pointer.
  if (!ptr)
    ptr = builder.create<LLVM::NullOp>(loc, getOpaquePointerType(builder));
  params.push_back(ptr);
}
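// For quick reference, the parameter vector built by `newParams()` has the
// following layout, which the rules below rely on when overwriting entries
// in place (e.g., `params[6]` and `params[7]`):
//
//   params[0] : per-dimension level-type annotations buffer
//   params[1] : dimension sizes buffer
//   params[2] : dimension order permutation buffer
//   params[3] : pointer type encoding
//   params[4] : index type encoding
//   params[5] : primary (element) type encoding
//   params[6] : action
//   params[7] : payload pointer (null, source tensor, or COO)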
/// Generates the code to read the value from tensor[ivs], and conditionally
/// stores the indices ivs to the memory in ind. The generated code looks like
/// the following, and the insertion point after this routine is inside the
/// if-then branch behind the assignment to ind. This is to ensure that the
/// addEltX call generated after this routine is inside the if-then branch.
///    if (tensor[ivs] != 0)
///      ind = ivs
static Value genIndexAndValueForDense(OpBuilder &builder, Location loc,
                                      Value tensor, Value ind, ValueRange ivs) {
  Value val = builder.create<tensor::ExtractOp>(loc, tensor, ivs);
  Value cond = genIsNonzero(builder, loc, val);
  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else*/ false);
  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
  unsigned i = 0;
  for (auto iv : ivs) {
    Value idx = constantIndex(builder, loc, i++);
    builder.create<memref::StoreOp>(loc, iv, ind, idx);
  }
  return val;
}

/// Generates a call to release/delete a `SparseTensorCOO`.
static void genDelCOOCall(OpBuilder &builder, Operation *op, Type elemTp,
                          Value coo) {
  SmallString<21> name{"delSparseTensorCOO", primaryTypeFunctionSuffix(elemTp)};
  TypeRange noTp;
  createFuncCall(builder, op, name, noTp, coo, EmitCInterface::Off);
}

/// Generates a call that adds one element to a coordinate scheme.
/// In particular, this generates code like the following:
///    val = a[i1,..,ik];
///    if val != 0
///      t->add(&val, [i1,..,ik], [p1,..,pk]);
static void genAddEltCall(OpBuilder &builder, Operation *op, Type eltType,
                          Value ptr, Value valPtr, Value ind, Value perm) {
  SmallString<9> name{"addElt", primaryTypeFunctionSuffix(eltType)};
  SmallVector<Value, 4> params{ptr, valPtr, ind, perm};
  Type pTp = getOpaquePointerType(builder);
  createFuncCall(builder, op, name, pTp, params, EmitCInterface::On);
}

/// Generates a call to `iter->getNext()`. If there is a next element,
/// then it is copied into the out-parameters `ind` and `elemPtr`,
/// and the return value is true. If there isn't a next element, then
/// the memory for `iter` is freed and the return value is false.
static Value genGetNextCall(OpBuilder &builder, Operation *op, Value iter,
                            Value ind, Value elemPtr) {
  Type elemTp = elemPtr.getType().cast<ShapedType>().getElementType();
  SmallString<10> name{"getNext", primaryTypeFunctionSuffix(elemTp)};
  SmallVector<Value, 3> params{iter, ind, elemPtr};
  Type i1 = builder.getI1Type();
  return createFuncCall(builder, op, name, i1, params, EmitCInterface::On)
      .getResult(0);
}
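// Together, `genNewCall()` with Action::kToIterator, `genGetNextCall()`, and
// `genDelCOOCall()` realize the iteration idiom used by several rules below
// (a pseudo-IR sketch, not generated verbatim; `getNextF64` assumes an f64
// element type):
//
//   %iter = call @newSparseTensor(...)  // Action::kToIterator
//   scf.while : () -> () {
//     %cond = call @getNextF64(%iter, %ind, %elemPtr)
//     scf.condition(%cond)
//   } do {
//     ... consume %ind and %elemPtr ...
//     scf.yield
//   }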
/// If the tensor is a sparse constant, generates and returns the pair of
/// the constants for the indices and the values.
static Optional<std::pair<Value, Value>>
genSplitSparseConstant(OpBuilder &builder, Location loc, Value tensor) {
  if (auto constOp = tensor.getDefiningOp<arith::ConstantOp>()) {
    if (auto attr = constOp.getValue().dyn_cast<SparseElementsAttr>()) {
      DenseElementsAttr indicesAttr = attr.getIndices();
      Value indices = builder.create<arith::ConstantOp>(loc, indicesAttr);
      DenseElementsAttr valuesAttr = attr.getValues();
      Value values = builder.create<arith::ConstantOp>(loc, valuesAttr);
      return std::make_pair(indices, values);
    }
  }
  return {};
}

/// Generates the code to copy the index at indices[ivs] to ind, and returns
/// the value at values[ivs].
static Value genIndexAndValueForSparse(OpBuilder &builder, Location loc,
                                       Value indices, Value values, Value ind,
                                       ValueRange ivs, unsigned rank) {
  for (unsigned i = 0; i < rank; i++) {
    Value idx = constantIndex(builder, loc, i);
    Value val = builder.create<tensor::ExtractOp>(loc, indices,
                                                  ValueRange{ivs[0], idx});
    val = builder.create<arith::IndexCastOp>(loc, builder.getIndexType(), val);
    builder.create<memref::StoreOp>(loc, val, ind, idx);
  }
  return builder.create<tensor::ExtractOp>(loc, values, ivs[0]);
}

/// Generates code to allocate a tensor of the given type, and zero
/// initialize it. If the tensor type has any dynamic sizes, then the
/// `sizes` parameter should be as filled by sizesFromPtr(); that way
/// we can reuse the genDimSizeCall() results generated by sizesFromPtr().
static Value allocDenseTensor(OpBuilder &builder, Location loc,
                              RankedTensorType tensorTp, ValueRange sizes) {
  Type elemTp = tensorTp.getElementType();
  auto shape = tensorTp.getShape();
  auto memTp = MemRefType::get(shape, elemTp);
  SmallVector<Value> dynamicSizes;
  for (unsigned i = 0, rank = tensorTp.getRank(); i < rank; i++) {
    if (shape[i] == ShapedType::kDynamicSize)
      dynamicSizes.push_back(sizes[i]);
  }
  Value mem = builder.create<memref::AllocOp>(loc, memTp, dynamicSizes);
  Value zero = constantZero(builder, loc, elemTp);
  builder.create<linalg::FillOp>(loc, ValueRange{zero}, ValueRange{mem});
  return mem;
}

/// Inserts the element returned by genGetNextCall(_, ind, elemPtr) into
/// the tensor created by allocDenseTensor(). The `rank` is the rank
/// of the `tensor` and the length of `ind`.
static void insertScalarIntoDenseTensor(OpBuilder &builder, Location loc,
                                        Value elemPtr, Value tensor,
                                        unsigned rank, Value ind) {
  SmallVector<Value, 4> ivs;
  ivs.reserve(rank);
  for (unsigned i = 0; i < rank; i++) {
    Value idx = constantIndex(builder, loc, i);
    ivs.push_back(builder.create<memref::LoadOp>(loc, ind, idx));
  }
  Value elemV = builder.create<memref::LoadOp>(loc, elemPtr);
  builder.create<memref::StoreOp>(loc, elemV, tensor, ivs);
}

/// Determines whether the runtime library supports direct conversion to the
/// given target `dimTypes`.
static bool canUseDirectConversion(
    ArrayRef<SparseTensorEncodingAttr::DimLevelType> dimTypes) {
  bool alreadyCompressed = false;
  for (uint64_t rank = dimTypes.size(), r = 0; r < rank; r++) {
    switch (dimTypes[r]) {
    case SparseTensorEncodingAttr::DimLevelType::Compressed:
      if (alreadyCompressed)
        return false; // Multiple compressed dimensions not yet supported.
      alreadyCompressed = true;
      break;
    case SparseTensorEncodingAttr::DimLevelType::Dense:
      if (alreadyCompressed)
        return false; // Dense after Compressed not yet supported.
      break;
    case SparseTensorEncodingAttr::DimLevelType::Singleton:
      // Although Singleton isn't generally supported yet, the direct
      // conversion method doesn't have any particular problems with
      // singleton after compressed.
      break;
    }
  }
  return true;
}

/// Helper method to translate indices during a reshaping operation.
/// TODO: provide as general utility to MLIR at large?
static void translateIndices(Location loc, ConversionPatternRewriter &rewriter,
                             ArrayRef<ReassociationIndices> reassociation,
                             TensorType dstTp, TensorType srcTp, Value dstIdx,
                             Value srcIdx) {
  unsigned dstRank = dstTp.getRank();
  unsigned srcRank = srcTp.getRank();
  unsigned start = 0;
  unsigned i = 0;
  bool isExpand = srcRank > dstRank;
  ArrayRef<int64_t> shape = isExpand ? srcTp.getShape() : dstTp.getShape();
  // Iterate over reassociation map.
  for (const auto &map : llvm::enumerate(reassociation)) {
    // Prepare strides information in dimension slice.
    uint64_t linear = 1;
    for (unsigned j = start, end = start + map.value().size(); j < end; j++) {
      assert(!ShapedType::isDynamic(shape[j]));
      linear *= shape[j];
    }
    // Start collapse.
    Value idx = constantIndex(rewriter, loc, i++);
    Value val;
    if (!isExpand)
      val = rewriter.create<memref::LoadOp>(loc, srcIdx, idx);
    // Iterate over dimension slice.
    for (unsigned j = start, end = start + map.value().size(); j < end; j++) {
      linear /= shape[j];
      Value stride = constantIndex(rewriter, loc, linear);
      Value jdx = constantIndex(rewriter, loc, j);
      if (isExpand) {
        Value old = rewriter.create<memref::LoadOp>(loc, srcIdx, jdx);
        Value mul = linear == 1
                        ? old
                        : rewriter.create<arith::MulIOp>(loc, old, stride);
        val = val ? rewriter.create<arith::AddIOp>(loc, val, mul) : mul;
      } else {
        Value old = val;
        if (linear != 1)
          val = rewriter.create<arith::DivUIOp>(loc, val, stride);
        rewriter.create<memref::StoreOp>(loc, val, dstIdx, jdx);
        if (linear != 1)
          val = rewriter.create<arith::RemUIOp>(loc, old, stride);
      }
    }
    // Finalize expansion.
    if (isExpand)
      rewriter.create<memref::StoreOp>(loc, val, dstIdx, idx);
    start += map.value().size();
  }
  // Sanity.
  assert((isExpand && i == dstRank) || (!isExpand && i == srcRank));
}
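// For example (a sketch, not generated verbatim): with reassociation
// [[0, 1]] and a 10x20 shape on the higher-rank side, the translation
// above linearizes an index pair (i, j) into i * 20 + j in one direction,
// and splits a linear index l back into (l / 20, l % 20) in the other.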
/// Generates code for a general sparse to sparse reshaping operation.
/// Note that unlike dense reshaping (which can be done with a "cheap"
/// change of view), sparse reshaping is currently done with actual
/// data shuffling.
///
/// TODO: proportional to nnz, but still a lot of data movement
/// https://github.com/llvm/llvm-project/issues/56477
///
///   iter = src->toCOO();
///   coo = newSparseCOO()
///   while (elem = iter->getNext()) {
///     coo->add(reshape(elem.indices), elem.value)
///   }
///   s = newSparseTensor(coo)
static LogicalResult
genSparse2SparseReshape(Operation *op, ConversionPatternRewriter &rewriter,
                        ArrayRef<ReassociationIndices> reassociation, Value src,
                        RankedTensorType dstTp, RankedTensorType srcTp) {
  Location loc = op->getLoc();
  auto encDst = getSparseTensorEncoding(dstTp);
  auto encSrc = getSparseTensorEncoding(srcTp);
  assert(encDst && encSrc);
  unsigned srcRank = srcTp.getRank();
  unsigned dstRank = dstTp.getRank();
  Type elemTp = srcTp.getElementType();
  assert(elemTp == dstTp.getElementType() &&
         "reshape should not change element type");
  // Start an iterator over the source tensor (in original index order).
  auto noPerm = SparseTensorEncodingAttr::get(
      op->getContext(), encSrc.getDimLevelType(), AffineMap(),
      encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
  SmallVector<Value, 4> sizes;
  SmallVector<Value, 8> params;
  sizesFromPtr(rewriter, sizes, op, noPerm, srcTp, src);
  newParams(rewriter, params, op, srcTp, noPerm, Action::kToIterator, sizes,
            src);
  Value iter = genNewCall(rewriter, op, params);
  // Start a new COO for the destination tensor.
  sizes.clear();
  params.clear();
  sizesFromPtr(rewriter, sizes, op, encDst, dstTp, src);
  newParams(rewriter, params, op, dstTp, encDst, Action::kEmptyCOO, sizes);
  Value coo = genNewCall(rewriter, op, params);
  Value dstPerm = params[2];
  // Construct a while loop over the iterator.
  Value srcIdx = genAlloca(rewriter, loc, srcRank, rewriter.getIndexType());
  Value dstIdx = genAlloca(rewriter, loc, dstRank, rewriter.getIndexType());
  Value elemPtr = genAllocaScalar(rewriter, loc, elemTp);
  SmallVector<Value> noArgs;
  SmallVector<Type> noTypes;
  auto whileOp = rewriter.create<scf::WhileOp>(loc, noTypes, noArgs);
  Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, noTypes);
  rewriter.setInsertionPointToEnd(before);
  Value cond = genGetNextCall(rewriter, op, iter, srcIdx, elemPtr);
  rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments());
  // Translate indices from source to target and insert. Note that we do
  // not need to store the value in elemPtr, as the value is still there.
  Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, noTypes);
  rewriter.setInsertionPointToStart(after);
  translateIndices(loc, rewriter, reassociation, dstTp, srcTp, dstIdx, srcIdx);
  genAddEltCall(rewriter, op, elemTp, coo, elemPtr, dstIdx, dstPerm);
  rewriter.create<scf::YieldOp>(loc);
  // Final call to construct sparse tensor storage and free temporary
  // resources.
  rewriter.setInsertionPointAfter(whileOp);
  params[6] = constantAction(rewriter, loc, Action::kFromCOO);
  params[7] = coo;
  Value dst = genNewCall(rewriter, op, params);
  genDelCOOCall(rewriter, op, elemTp, coo);
  genDelCOOCall(rewriter, op, elemTp, iter);
  rewriter.replaceOp(op, dst);
  return success();
}

//===----------------------------------------------------------------------===//
// Conversion rules.
//===----------------------------------------------------------------------===//

/// Sparse conversion rule for returns.
class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for dimension accesses.
class SparseTensorToDimSizeConverter
    : public OpConversionPattern<tensor::DimOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only rewrite annotated DimOp with constant index.
    auto enc = getSparseTensorEncoding(op.getSource().getType());
    if (!enc)
      return failure();
    Optional<int64_t> index = op.getConstantIndex();
    if (!index)
      return failure();
    // Generate the call.
    Value src = adaptor.getOperands()[0];
    int64_t idx = *index;
    rewriter.replaceOp(op, genDimSizeCall(rewriter, op, enc, src, idx));
    return success();
  }
};

/// Sparse conversion rule for trivial tensor casts.
class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only rewrite identically annotated source/dest.
    auto encDst = getSparseTensorEncoding(op.getType());
    auto encSrc = getSparseTensorEncoding(op.getSource().getType());
    if (!encDst || encDst != encSrc)
      return failure();
    rewriter.replaceOp(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for a reshape operator.
template <typename ReshapeOp>
class SparseReshapeConverter : public OpConversionPattern<ReshapeOp> {
public:
  using OpAdaptor = typename OpConversionPattern<ReshapeOp>::OpAdaptor;
  using OpConversionPattern<ReshapeOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReshapeOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type dstType = op.getResult().getType();
    Type srcType = op.getSrc().getType();
    auto encDst = getSparseTensorEncoding(dstType);
    auto encSrc = getSparseTensorEncoding(srcType);
    if (encDst && encSrc)
      return genSparse2SparseReshape(
          op, rewriter, op.getReassociationIndices(), adaptor.getOperands()[0],
          dstType.cast<RankedTensorType>(), srcType.cast<RankedTensorType>());
    return failure(); // handled elsewhere
  }
};

/// Sparse conversion rule for the new operator.
class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NewOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    auto enc = getSparseTensorEncoding(resType);
    if (!enc)
      return failure();
    // Generate the call to construct tensor from ptr. The sizes are
    // inferred from the result type of the new operator.
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 8> params;
    ShapedType stp = resType.cast<ShapedType>();
    sizesFromType(rewriter, sizes, op.getLoc(), stp);
    Value ptr = adaptor.getOperands()[0];
    newParams(rewriter, params, op, stp, enc, Action::kFromFile, sizes, ptr);
    rewriter.replaceOp(op, genNewCall(rewriter, op, params));
    return success();
  }
};

/// Sparse conversion rule for the alloc operator.
class SparseTensorAllocConverter
    : public OpConversionPattern<bufferization::AllocTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(bufferization::AllocTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    RankedTensorType resType = op.getType();
    auto enc = getSparseTensorEncoding(resType);
    if (!enc)
      return failure();
    // Gather all dimension sizes as SSA values.
    SmallVector<Value> sizes;
    unsigned int operandCtr = 0;
    for (int64_t i = 0; i < resType.getRank(); ++i) {
      if (resType.isDynamicDim(i)) {
        sizes.push_back(adaptor.getOperands()[operandCtr++]);
      } else {
        sizes.push_back(rewriter.create<arith::ConstantIndexOp>(
            op.getLoc(), op.getStaticSize(i)));
      }
    }
    // Generate the call to construct an empty tensor. The sizes are
    // explicitly defined by the arguments to the alloc operator.
    SmallVector<Value, 8> params;
    ShapedType stp = resType.cast<ShapedType>();
    newParams(rewriter, params, op, stp, enc, Action::kEmpty, sizes);
    rewriter.replaceOp(op, genNewCall(rewriter, op, params));
    return success();
  }
};

/// Sparse conversion rule for the convert operator.
class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  SparseTensorConvertConverter(MLIRContext *context,
                               SparseTensorConversionOptions o)
      : OpConversionPattern<ConvertOp>(context), options(o) {}
  SparseTensorConvertConverter(TypeConverter &typeConv, MLIRContext *context,
                               SparseTensorConversionOptions o)
      : OpConversionPattern<ConvertOp>(typeConv, context), options(o) {}

  LogicalResult
  matchAndRewrite(ConvertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    Type resType = op.getType();
    Type srcType = op.getSource().getType();
    auto encDst = getSparseTensorEncoding(resType);
    auto encSrc = getSparseTensorEncoding(srcType);
    Value src = adaptor.getOperands()[0];
    if (encDst && encSrc) {
      // This is a sparse => sparse conversion, which is handled as follows:
      //   t = src->toCOO();         ; src to COO in dst order
      //   dst = newSparseTensor(t)
      // Using the coordinate scheme as an intermediate does not always
      // yield the fastest conversion but avoids the need for a full
      // O(N^2) conversion matrix.
      if (encDst == encSrc) {
        rewriter.replaceOp(op, adaptor.getOperands()); // hidden nop cast
        return success();
      }
      SmallVector<Value, 4> sizes;
      SmallVector<Value, 8> params;
      ShapedType stp = srcType.cast<ShapedType>();
      sizesFromPtr(rewriter, sizes, op, encSrc, stp, src);
      bool useDirectConversion;
      switch (options.sparseToSparseStrategy) {
      case SparseToSparseConversionStrategy::kViaCOO:
        useDirectConversion = false;
        break;
      case SparseToSparseConversionStrategy::kDirect:
        useDirectConversion = true;
        assert(canUseDirectConversion(encDst.getDimLevelType()) &&
               "Unsupported target for direct sparse-to-sparse conversion");
        break;
      case SparseToSparseConversionStrategy::kAuto:
        useDirectConversion = canUseDirectConversion(encDst.getDimLevelType());
        break;
      }
      if (useDirectConversion) {
        newParams(rewriter, params, op, stp, encDst, Action::kSparseToSparse,
                  sizes, src);
        rewriter.replaceOp(op, genNewCall(rewriter, op, params));
      } else { // use via-COO conversion.
        // Set up encoding with the right mix of src and dst so that the two
        // method calls can share most parameters, while still providing
        // the correct sparsity information to either of them.
        auto enc = SparseTensorEncodingAttr::get(
            op->getContext(), encDst.getDimLevelType(), encDst.getDimOrdering(),
            encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
        newParams(rewriter, params, op, stp, enc, Action::kToCOO, sizes, src);
        Value coo = genNewCall(rewriter, op, params);
        params[3] = constantPointerTypeEncoding(rewriter, loc, encDst);
        params[4] = constantIndexTypeEncoding(rewriter, loc, encDst);
        params[6] = constantAction(rewriter, loc, Action::kFromCOO);
        params[7] = coo;
        Value dst = genNewCall(rewriter, op, params);
        genDelCOOCall(rewriter, op, stp.getElementType(), coo);
        rewriter.replaceOp(op, dst);
      }
      return success();
    }
    if (!encDst && encSrc) {
      // This is a sparse => dense conversion, which is handled as follows:
      //   dst = new Tensor(0);
      //   iter = src->toCOO();
      //   iter->startIterator();
      //   while (elem = iter->getNext()) {
      //     dst[elem.indices] = elem.value;
      //   }
      RankedTensorType dstTensorTp = resType.cast<RankedTensorType>();
      RankedTensorType srcTensorTp = srcType.cast<RankedTensorType>();
      unsigned rank = dstTensorTp.getRank();
      Type elemTp = dstTensorTp.getElementType();
      // Fabricate a no-permutation encoding for newParams().
      // The pointer/index types must be those of `src`.
      // The dimLevelTypes aren't actually used by Action::kToIterator.
      encDst = SparseTensorEncodingAttr::get(
          op->getContext(),
          SmallVector<SparseTensorEncodingAttr::DimLevelType>(
              rank, SparseTensorEncodingAttr::DimLevelType::Dense),
          AffineMap(), encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
      SmallVector<Value, 4> sizes;
      SmallVector<Value, 8> params;
      sizesFromPtr(rewriter, sizes, op, encSrc, srcTensorTp, src);
      newParams(rewriter, params, op, dstTensorTp, encDst, Action::kToIterator,
                sizes, src);
      Value iter = genNewCall(rewriter, op, params);
      Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
      Value elemPtr = genAllocaScalar(rewriter, loc, elemTp);
      Value dst = allocDenseTensor(rewriter, loc, dstTensorTp, sizes);
      SmallVector<Value> noArgs;
      SmallVector<Type> noTypes;
      auto whileOp = rewriter.create<scf::WhileOp>(loc, noTypes, noArgs);
      Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, noTypes);
      rewriter.setInsertionPointToEnd(before);
      Value cond = genGetNextCall(rewriter, op, iter, ind, elemPtr);
      rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments());
      Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, noTypes);
      rewriter.setInsertionPointToStart(after);
      insertScalarIntoDenseTensor(rewriter, loc, elemPtr, dst, rank, ind);
      rewriter.create<scf::YieldOp>(loc);
      rewriter.setInsertionPointAfter(whileOp);
      genDelCOOCall(rewriter, op, elemTp, iter);
      rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, resType, dst);
      return success();
    }
    if (!encDst && !encSrc) {
      // dense => dense
      return failure();
    }
    // This is a dense => sparse conversion or a sparse constant in COO =>
    // sparse conversion, which is handled as follows:
    //   t = newSparseCOO()
    //   ...code to fill the COO tensor t...
    //   s = newSparseTensor(t)
    //
    // To fill the COO tensor from a dense tensor:
    //   for i1 in dim1
    //    ..
    //     for ik in dimk
    //       val = a[i1,..,ik]
    //       if val != 0
    //         t->add(val, [i1,..,ik], [p1,..,pk])
    //
    // To fill the COO tensor from a sparse constant in COO format:
    //   for i in range(NNZ)
    //     val = values[i]
    //     [i1,..,ik] = indices[i]
    //     t->add(val, [i1,..,ik], [p1,..,pk])
    //
    // Note that the dense tensor traversal code is actually implemented
    // using MLIR IR to avoid having to expose too many low-level
    // memref traversal details to the runtime support library.
    // Also note that the code below only generates the "new" ops and
    // the loop-nest per se; the entire body of the innermost loop is
    // generated by genAddEltCall().
    ShapedType stp = resType.cast<ShapedType>();
    unsigned rank = stp.getRank();
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 8> params;
    sizesFromSrc(rewriter, sizes, loc, src);
    newParams(rewriter, params, op, stp, encDst, Action::kEmptyCOO, sizes);
    Value coo = genNewCall(rewriter, op, params);
    Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
    Value perm = params[2];
    SmallVector<Value> lo;
    SmallVector<Value> hi;
    SmallVector<Value> st;
    Value zero = constantIndex(rewriter, loc, 0);
    Value one = constantIndex(rewriter, loc, 1);
    auto indicesValues = genSplitSparseConstant(rewriter, loc, src);
    bool isCOOConstant = indicesValues.has_value();
    Value indices;
    Value values;
    if (isCOOConstant) {
      indices = indicesValues->first;
      values = indicesValues->second;
      lo.push_back(zero);
      hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, values, 0));
      st.push_back(one);
    } else {
      for (unsigned i = 0; i < rank; i++) {
        lo.push_back(zero);
        hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, src, i));
        st.push_back(one);
      }
    }
    Type eltType = stp.getElementType();
    Value elemPtr = genAllocaScalar(rewriter, loc, eltType);
    scf::buildLoopNest(
        rewriter, op.getLoc(), lo, hi, st, {},
        [&](OpBuilder &builder, Location loc, ValueRange ivs,
            ValueRange args) -> scf::ValueVector {
          Value val;
          if (isCOOConstant)
            val = genIndexAndValueForSparse(rewriter, loc, indices, values, ind,
                                            ivs, rank);
          else
            val = genIndexAndValueForDense(rewriter, loc, src, ind, ivs);
          builder.create<memref::StoreOp>(loc, val, elemPtr);
          genAddEltCall(rewriter, op, eltType, coo, elemPtr, ind, perm);
          return {};
        });
    // Final call to construct sparse tensor storage.
    params[6] = constantAction(rewriter, loc, Action::kFromCOO);
    params[7] = coo;
    Value dst = genNewCall(rewriter, op, params);
    genDelCOOCall(rewriter, op, eltType, coo);
    rewriter.replaceOp(op, dst);
    return success();
  }

private:
  /// Options to control sparse code generation.
  SparseTensorConversionOptions options;
};

/// Sparse conversion rule for the release operator.
class SparseTensorReleaseConverter : public OpConversionPattern<ReleaseOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReleaseOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    StringRef name = "delSparseTensor";
    TypeRange noTp;
    createFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
                   EmitCInterface::Off);
    rewriter.eraseOp(op);
    return success();
  }
};

/// Sparse conversion rule for pointer accesses.
class SparseTensorToPointersConverter
    : public OpConversionPattern<ToPointersOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToPointersOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type ptrType = resType.cast<ShapedType>().getElementType();
    SmallString<16> name{"sparsePointers", overheadTypeFunctionSuffix(ptrType)};
    replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
                          EmitCInterface::On);
    return success();
  }
};
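// The pointer/index/value access rules (above and below) all lower to a
// single runtime call whose name concatenates a base name with a type
// suffix chosen by overheadTypeFunctionSuffix() or
// primaryTypeFunctionSuffix(). For instance (a sketch, assuming an i64
// overhead type):
//
//   %m = sparse_tensor.pointers %t, %c1 : ... to memref<?xi64>
//
// becomes roughly:
//
//   %m = call @sparsePointers64(%t, %c1)
//       : (!llvm.ptr<i8>, index) -> memref<?xi64>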
/// Sparse conversion rule for index accesses.
class SparseTensorToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToIndicesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type indType = resType.cast<ShapedType>().getElementType();
    SmallString<15> name{"sparseIndices", overheadTypeFunctionSuffix(indType)};
    replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
                          EmitCInterface::On);
    return success();
  }
};

/// Sparse conversion rule for value accesses.
class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type eltType = resType.cast<ShapedType>().getElementType();
    SmallString<15> name{"sparseValues", primaryTypeFunctionSuffix(eltType)};
    replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
                          EmitCInterface::On);
    return success();
  }
};

/// Sparse conversion rule for tensor rematerialization.
class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getHasInserts()) {
      // Finalize any pending insertions.
      StringRef name = "endInsert";
      TypeRange noTp;
      createFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
                     EmitCInterface::Off);
    }
    rewriter.replaceOp(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for inserting in lexicographic index order.
class SparseTensorLexInsertConverter : public OpConversionPattern<LexInsertOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(LexInsertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type elemTp = op.getTensor().getType().cast<ShapedType>().getElementType();
    SmallString<12> name{"lexInsert", primaryTypeFunctionSuffix(elemTp)};
    TypeRange noTp;
    replaceOpWithFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
                          EmitCInterface::On);
    return success();
  }
};

/// Sparse conversion rule for the expand operator.
class SparseTensorExpandConverter : public OpConversionPattern<ExpandOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ExpandOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    ShapedType srcType = op.getTensor().getType().cast<ShapedType>();
    Type eltType = srcType.getElementType();
    Type boolType = rewriter.getIntegerType(1);
    Type idxType = rewriter.getIndexType();
    // All initialization should be done on entry of the loop nest.
    rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp());
    // Determine the size for access expansion.
    auto enc = getSparseTensorEncoding(srcType);
    Value src = adaptor.getOperands()[0];
    Value sz = genDimSizeCall(rewriter, op, enc, src, srcType.getRank() - 1);
    // Allocate temporary buffers for values, filled-switch, and indices.
    // We do not use stack buffers for this, since the expanded size may
    // be rather large (as it envelops a single expanded dense dimension).
    Value values = genAlloc(rewriter, loc, sz, eltType);
    Value filled = genAlloc(rewriter, loc, sz, boolType);
    Value indices = genAlloc(rewriter, loc, sz, idxType);
    Value zero = constantZero(rewriter, loc, idxType);
    // Reset the values/filled-switch to all-zero/false. Note that this
    // introduces an O(N) operation into the computation, but this reset
    // operation is amortized over the innermost loops for the access
    // pattern expansion. As noted in the operation doc, we would like
    // to amortize this setup cost even between kernels.
    rewriter.create<linalg::FillOp>(
        loc, ValueRange{constantZero(rewriter, loc, eltType)},
        ValueRange{values});
    rewriter.create<linalg::FillOp>(
        loc, ValueRange{constantZero(rewriter, loc, boolType)},
        ValueRange{filled});
    // Replace expansion op with these buffers and initial index.
    assert(op.getNumResults() == 4);
    rewriter.replaceOp(op, {values, filled, indices, zero});
    return success();
  }
};

/// Sparse conversion rule for the compress operator.
class SparseTensorCompressConverter : public OpConversionPattern<CompressOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(CompressOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    // Note that this method call resets the values/filled-switch back to
    // all-zero/false by only iterating over the set elements, so the
    // complexity remains proportional to the sparsity of the expanded
    // access pattern.
    Type elemTp = op.getTensor().getType().cast<ShapedType>().getElementType();
    SmallString<12> name{"expInsert", primaryTypeFunctionSuffix(elemTp)};
    TypeRange noTp;
    replaceOpWithFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
                          EmitCInterface::On);
    // Deallocate the buffers on exit of the loop nest.
    Operation *parent = op;
    for (; isa<scf::ForOp>(parent->getParentOp()) ||
           isa<scf::WhileOp>(parent->getParentOp()) ||
           isa<scf::ParallelOp>(parent->getParentOp()) ||
           isa<scf::IfOp>(parent->getParentOp());
         parent = parent->getParentOp())
      ;
    rewriter.setInsertionPointAfter(parent);
    rewriter.create<memref::DeallocOp>(loc, adaptor.getOperands()[2]);
    rewriter.create<memref::DeallocOp>(loc, adaptor.getOperands()[3]);
    rewriter.create<memref::DeallocOp>(loc, adaptor.getOperands()[4]);
    return success();
  }
};

/// Sparse conversion rule for the output operator.
class SparseTensorOutConverter : public OpConversionPattern<OutOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(OutOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    ShapedType srcType = op.getTensor().getType().cast<ShapedType>();
    // Convert to default permuted COO.
    Value src = adaptor.getOperands()[0];
    auto encSrc = getSparseTensorEncoding(srcType);
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 8> params;
    sizesFromPtr(rewriter, sizes, op, encSrc, srcType, src);
    auto enc = SparseTensorEncodingAttr::get(
        op->getContext(), encSrc.getDimLevelType(), AffineMap(),
        encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
    newParams(rewriter, params, op, srcType, enc, Action::kToCOO, sizes, src);
    Value coo = genNewCall(rewriter, op, params);
    // Then output the tensor to an external file with the indices in the
    // externally visible lexicographic index order. A sort is required if
    // the source was not in that order yet (note that the sort can be
    // dropped altogether if the external format does not care about the
    // order at all, but here we assume it does).
    bool sort =
        encSrc.getDimOrdering() && !encSrc.getDimOrdering().isIdentity();
    params.clear();
    params.push_back(coo);
    params.push_back(adaptor.getOperands()[1]);
    params.push_back(constantI1(rewriter, loc, sort));
    Type eltType = srcType.getElementType();
    SmallString<18> name{"outSparseTensor", primaryTypeFunctionSuffix(eltType)};
    TypeRange noTp;
    createFuncCall(rewriter, op, name, noTp, params, EmitCInterface::Off);
    genDelCOOCall(rewriter, op, eltType, coo);
    rewriter.eraseOp(op);
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with conversion rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparseTensorConversionPatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    const SparseTensorConversionOptions &options) {
  patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter,
               SparseCastConverter, SparseTensorNewConverter,
               SparseReshapeConverter<tensor::ExpandShapeOp>,
               SparseReshapeConverter<tensor::CollapseShapeOp>,
               SparseTensorAllocConverter, SparseTensorReleaseConverter,
               SparseTensorToPointersConverter, SparseTensorToIndicesConverter,
               SparseTensorToValuesConverter, SparseTensorLoadConverter,
               SparseTensorLexInsertConverter, SparseTensorExpandConverter,
               SparseTensorCompressConverter, SparseTensorOutConverter>(
      typeConverter, patterns.getContext());
  patterns.add<SparseTensorConvertConverter>(typeConverter,
                                             patterns.getContext(), options);
}
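// Note: these patterns are meant to be driven by a dialect conversion that
// also installs a type converter mapping each annotated tensor type to the
// runtime's opaque !llvm.ptr<i8> handle (see the sparse tensor conversion
// pass). A minimal driver sketch (names illustrative, not the actual pass):
//
//   RewritePatternSet patterns(ctx);
//   populateSparseTensorConversionPatterns(typeConverter, patterns, options);
//   if (failed(applyPartialConversion(module, target, std::move(patterns))))
//     signalPassFailure();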