//===- SparseTensorConversion.cpp - Sparse tensor primitives conversion ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Convert sparse tensor primitives to calls into a runtime support library.
// Note that this is a current implementation choice to keep the conversion
// simple. In principle, these primitives could also be converted to actual
// elaborate IR code that implements the primitives on the selected sparse
// tensor storage schemes.
//
//===----------------------------------------------------------------------===//

#include "CodegenUtils.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/ExecutionEngine/SparseTensorUtils.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::sparse_tensor;
using mlir::bufferization::BufferizableOpInterface;
using mlir::bufferization::BufferizationDialect;

namespace {

/// Shorthand aliases for the `emitCInterface` argument to `getFunc()`,
/// `createFuncCall()`, and `replaceOpWithFuncCall()`.
enum class EmitCInterface : bool { Off = false, On = true };

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

/// Returns the equivalent of `void*` for opaque arguments to the
/// execution engine.
static Type getOpaquePointerType(OpBuilder &builder) {
  return LLVM::LLVMPointerType::get(builder.getI8Type());
}

/// Returns a function reference (first hit also inserts into module). Sets
/// the "_emit_c_interface" on the function declaration when requested,
/// so that LLVM lowering generates a wrapper function that takes care
/// of ABI complications with passing in and returning MemRefs to C functions.
static FlatSymbolRefAttr getFunc(Operation *op, StringRef name,
                                 TypeRange resultType, ValueRange operands,
                                 EmitCInterface emitCInterface) {
  MLIRContext *context = op->getContext();
  auto module = op->getParentOfType<ModuleOp>();
  auto result = SymbolRefAttr::get(context, name);
  auto func = module.lookupSymbol<func::FuncOp>(result.getAttr());
  if (!func) {
    OpBuilder moduleBuilder(module.getBodyRegion());
    func = moduleBuilder.create<func::FuncOp>(
        op->getLoc(), name,
        FunctionType::get(context, operands.getTypes(), resultType));
    func.setPrivate();
    if (static_cast<bool>(emitCInterface))
      func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
                    UnitAttr::get(context));
  }
  return result;
}
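
// As a schematic illustration (function name purely for exposition), a first
// call to createFuncCall(builder, op, "foo", {indexTp}, {%ptr},
// EmitCInterface::On) declares
//
//   func.func private @foo(!llvm.ptr<i8>) -> index attributes { llvm.emit_c_interface }
//
// at module level, and every call (including the first) then emits
//
//   %0 = func.call @foo(%ptr) : (!llvm.ptr<i8>) -> index
//
// at the current insertion point.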

/// Creates a `CallOp` to the function reference returned by `getFunc()`.
static func::CallOp createFuncCall(OpBuilder &builder, Operation *op,
                                   StringRef name, TypeRange resultType,
                                   ValueRange operands,
                                   EmitCInterface emitCInterface) {
  auto fn = getFunc(op, name, resultType, operands, emitCInterface);
  return builder.create<func::CallOp>(op->getLoc(), resultType, fn, operands);
}

/// Replaces the `op` with a `CallOp` to the function reference returned
/// by `getFunc()`.
static func::CallOp replaceOpWithFuncCall(RewriterBase &rewriter, Operation *op,
                                          StringRef name, TypeRange resultType,
                                          ValueRange operands,
                                          EmitCInterface emitCInterface) {
  auto fn = getFunc(op, name, resultType, operands, emitCInterface);
  return rewriter.replaceOpWithNewOp<func::CallOp>(op, resultType, fn,
                                                   operands);
}

/// Generates dimension size call.
static Value genDimSizeCall(OpBuilder &builder, Operation *op,
                            SparseTensorEncodingAttr &enc, Value src,
                            int64_t idx) {
  // Permute the index according to an optional dimension ordering.
  if (AffineMap p = enc.getDimOrdering())
    idx = p.getPermutedPosition(idx);
  // Generate the call.
  StringRef name = "sparseDimSize";
  SmallVector<Value, 2> params{src, constantIndex(builder, op->getLoc(), idx)};
  Type iTp = builder.getIndexType();
  return createFuncCall(builder, op, name, iTp, params, EmitCInterface::Off)
      .getResult(0);
}

/// Generates a call into the "swiss army knife" method of the sparse runtime
/// support library for materializing sparse tensors into the computation.
static Value genNewCall(OpBuilder &builder, Operation *op,
                        ArrayRef<Value> params) {
  StringRef name = "newSparseTensor";
  Type pTp = getOpaquePointerType(builder);
  return createFuncCall(builder, op, name, pTp, params, EmitCInterface::On)
      .getResult(0);
}

/// Populates given sizes array from type.
static void sizesFromType(OpBuilder &builder, SmallVector<Value, 4> &sizes,
                          Location loc, ShapedType stp) {
  auto shape = stp.getShape();
  for (unsigned i = 0, rank = stp.getRank(); i < rank; i++) {
    uint64_t s = shape[i] == ShapedType::kDynamicSize ? 0 : shape[i];
    sizes.push_back(constantIndex(builder, loc, s));
  }
}

/// Populates given sizes array from source.
static void sizesFromSrc(OpBuilder &builder, SmallVector<Value, 4> &sizes,
                         Location loc, Value src) {
  unsigned rank = src.getType().cast<ShapedType>().getRank();
  for (unsigned i = 0; i < rank; i++)
    sizes.push_back(linalg::createOrFoldDimOp(builder, loc, src, i));
}

/// Populates given sizes array from type (for static sizes) and from
/// a source that has already been converted to an opaque pointer
/// (for dynamic sizes).
static void sizesFromPtr(OpBuilder &builder, SmallVector<Value, 4> &sizes,
                         Operation *op, SparseTensorEncodingAttr &enc,
                         ShapedType stp, Value src) {
  Location loc = op->getLoc();
  auto shape = stp.getShape();
  for (unsigned i = 0, rank = stp.getRank(); i < rank; i++)
    if (shape[i] == ShapedType::kDynamicSize)
      sizes.push_back(genDimSizeCall(builder, op, enc, src, i));
    else
      sizes.push_back(constantIndex(builder, loc, shape[i]));
}

/// Generates an uninitialized temporary buffer of the given size and
/// type, but returns it as type `memref<? x $tp>` (rather than as type
/// `memref<$sz x $tp>`).
static Value genAlloca(OpBuilder &builder, Location loc, Value sz, Type tp) {
  auto memTp = MemRefType::get({ShapedType::kDynamicSize}, tp);
  return builder.create<memref::AllocaOp>(loc, memTp, ValueRange{sz});
}

/// Generates an uninitialized buffer of the given size and type,
/// but returns it as type `memref<? x $tp>` (rather than as type
/// `memref<$sz x $tp>`). Unlike temporary buffers on the stack,
/// this buffer must be explicitly deallocated by client.
static Value genAlloc(RewriterBase &rewriter, Location loc, Value sz, Type tp) {
  auto memTp = MemRefType::get({ShapedType::kDynamicSize}, tp);
  return rewriter.create<memref::AllocOp>(loc, memTp, ValueRange{sz});
}

/// Generates an uninitialized temporary buffer of the given size and
/// type, but returns it as type `memref<? x $tp>` (rather than as type
/// `memref<$sz x $tp>`).
static Value genAlloca(OpBuilder &builder, Location loc, unsigned sz, Type tp) {
  return genAlloca(builder, loc, constantIndex(builder, loc, sz), tp);
}

/// Generates an uninitialized temporary buffer with room for one value
/// of the given type, and returns the `memref<$tp>`.
static Value genAllocaScalar(OpBuilder &builder, Location loc, Type tp) {
  return builder.create<memref::AllocaOp>(loc, MemRefType::get({}, tp));
}

/// Generates a temporary buffer of the given type and given contents.
static Value genBuffer(OpBuilder &builder, Location loc, ValueRange values) {
  unsigned sz = values.size();
  assert(sz >= 1);
  Value buffer = genAlloca(builder, loc, sz, values[0].getType());
  for (unsigned i = 0; i < sz; i++) {
    Value idx = constantIndex(builder, loc, i);
    builder.create<memref::StoreOp>(loc, values[i], buffer, idx);
  }
  return buffer;
}

/// Populates parameters required to call the "swiss army knife" method of the
/// sparse runtime support library for materializing sparse tensors into the
/// computation.
static void newParams(OpBuilder &builder, SmallVector<Value, 8> &params,
                      Operation *op, ShapedType stp,
                      SparseTensorEncodingAttr &enc, Action action,
                      ValueRange szs, Value ptr = Value()) {
  Location loc = op->getLoc();
  ArrayRef<SparseTensorEncodingAttr::DimLevelType> dlt = enc.getDimLevelType();
  unsigned sz = dlt.size();
  // Sparsity annotations.
  SmallVector<Value, 4> attrs;
  for (unsigned i = 0; i < sz; i++)
    attrs.push_back(constantDimLevelTypeEncoding(builder, loc, dlt[i]));
  params.push_back(genBuffer(builder, loc, attrs));
  // Dimension sizes array of the enveloping tensor. Useful for either
  // verification of external data, or for construction of internal data.
  params.push_back(genBuffer(builder, loc, szs));
  // Dimension order permutation array. This is the "identity" permutation by
  // default, or otherwise the "reverse" permutation of a given ordering, so
  // that indices can be mapped quickly to the right position.
  SmallVector<Value, 4> rev(sz);
  if (AffineMap p = enc.getDimOrdering()) {
    for (unsigned i = 0; i < sz; i++)
      rev[p.getDimPosition(i)] = constantIndex(builder, loc, i);
  } else {
    for (unsigned i = 0; i < sz; i++)
      rev[i] = constantIndex(builder, loc, i);
  }
  params.push_back(genBuffer(builder, loc, rev));
  // Secondary and primary types encoding.
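  // At this point `params` holds, in order: the annotations buffer, the sizes
  // buffer, and the permutation buffer (slots 0..2). The type encodings, the
  // action, and the payload pointer pushed below fill slots 3..7; several
  // conversion patterns later overwrite individual slots (e.g. params[6] and
  // params[7]) before re-issuing the newSparseTensor call.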
  Type elemTp = stp.getElementType();
  params.push_back(constantPointerTypeEncoding(builder, loc, enc));
  params.push_back(constantIndexTypeEncoding(builder, loc, enc));
  params.push_back(constantPrimaryTypeEncoding(builder, loc, elemTp));
  // User action.
  params.push_back(constantAction(builder, loc, action));
  // Payload pointer.
  if (!ptr)
    ptr = builder.create<LLVM::NullOp>(loc, getOpaquePointerType(builder));
  params.push_back(ptr);
}

/// Generates the code to read the value from tensor[ivs], and conditionally
/// stores the indices ivs to the memory in ind. The generated code looks like
/// the following and the insertion point after this routine is inside the
/// if-then branch behind the assignment to ind. This is to ensure that the
/// addEltX call generated after is inside the if-then branch.
///    if (tensor[ivs] != 0)
///      ind = ivs
static Value genIndexAndValueForDense(OpBuilder &builder, Location loc,
                                      Value tensor, Value ind, ValueRange ivs) {
  Value val = builder.create<tensor::ExtractOp>(loc, tensor, ivs);
  Value cond = genIsNonzero(builder, loc, val);
  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else*/ false);
  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
  unsigned i = 0;
  for (auto iv : ivs) {
    Value idx = constantIndex(builder, loc, i++);
    builder.create<memref::StoreOp>(loc, iv, ind, idx);
  }
  return val;
}

/// Generates a call to release/delete a `SparseTensorCOO`.
static void genDelCOOCall(OpBuilder &builder, Operation *op, Type elemTp,
                          Value coo) {
  SmallString<21> name{"delSparseTensorCOO", primaryTypeFunctionSuffix(elemTp)};
  TypeRange noTp;
  createFuncCall(builder, op, name, noTp, coo, EmitCInterface::Off);
}

/// Generates a call that adds one element to a coordinate scheme.
/// In particular, this generates code like the following:
///   val = a[i1,..,ik];
///   if val != 0
///     t->add(&val, [i1,..,ik], [p1,..,pk]);
static void genAddEltCall(OpBuilder &builder, Operation *op, Type eltType,
                          Value ptr, Value valPtr, Value ind, Value perm) {
  SmallString<9> name{"addElt", primaryTypeFunctionSuffix(eltType)};
  SmallVector<Value, 4> params{ptr, valPtr, ind, perm};
  Type pTp = getOpaquePointerType(builder);
  createFuncCall(builder, op, name, pTp, params, EmitCInterface::On);
}

/// Generates a call to `iter->getNext()`. If there is a next element,
/// then it is copied into the out-parameters `ind` and `elemPtr`,
/// and the return value is true. If there isn't a next element, then
/// the memory for `iter` is freed and the return value is false.
static Value genGetNextCall(OpBuilder &builder, Operation *op, Value iter,
                            Value ind, Value elemPtr) {
  Type elemTp = elemPtr.getType().cast<ShapedType>().getElementType();
  SmallString<10> name{"getNext", primaryTypeFunctionSuffix(elemTp)};
  SmallVector<Value, 3> params{iter, ind, elemPtr};
  Type i1 = builder.getI1Type();
  return createFuncCall(builder, op, name, i1, params, EmitCInterface::On)
      .getResult(0);
}
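
// The helpers above are typically stitched into an `scf.while` that drains a
// COO iterator; schematically (the callee suffix depends on the element type,
// e.g. `getNextF64` for `f64` values):
//
//   scf.while : () -> () {
//     %more = func.call @getNextF64(%iter, %ind, %elemPtr) : (...) -> i1
//     scf.condition(%more)
//   } do {
//     ... consume %ind and %elemPtr ...
//     scf.yield
//   }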

/// If the tensor is a sparse constant, generates and returns the pair of
/// the constants for the indices and the values.
static Optional<std::pair<Value, Value>>
genSplitSparseConstant(OpBuilder &builder, Location loc, Value tensor) {
  if (auto constOp = tensor.getDefiningOp<arith::ConstantOp>()) {
    if (auto attr = constOp.getValue().dyn_cast<SparseElementsAttr>()) {
      DenseElementsAttr indicesAttr = attr.getIndices();
      Value indices = builder.create<arith::ConstantOp>(loc, indicesAttr);
      DenseElementsAttr valuesAttr = attr.getValues();
      Value values = builder.create<arith::ConstantOp>(loc, valuesAttr);
      return std::make_pair(indices, values);
    }
  }
  return {};
}

/// Generates the code to copy the index at indices[ivs] to ind, and return
/// the value at value[ivs].
static Value genIndexAndValueForSparse(OpBuilder &builder, Location loc,
                                       Value indices, Value values, Value ind,
                                       ValueRange ivs, unsigned rank) {
  for (unsigned i = 0; i < rank; i++) {
    Value idx = constantIndex(builder, loc, i);
    Value val = builder.create<tensor::ExtractOp>(loc, indices,
                                                  ValueRange{ivs[0], idx});
    val = builder.create<arith::IndexCastOp>(loc, builder.getIndexType(), val);
    builder.create<memref::StoreOp>(loc, val, ind, idx);
  }
  return builder.create<tensor::ExtractOp>(loc, values, ivs[0]);
}

/// Generates code to allocate a buffer of the given type, and zero
/// initialize it. If the buffer type has any dynamic sizes, then the
/// `sizes` parameter should be as filled by sizesFromPtr(); that way
/// we can reuse the genDimSizeCall() results generated by sizesFromPtr().
static Value allocDenseTensor(OpBuilder &builder, Location loc,
                              RankedTensorType tensorTp, ValueRange sizes) {
  Type elemTp = tensorTp.getElementType();
  auto shape = tensorTp.getShape();
  auto memTp = MemRefType::get(shape, elemTp);
  SmallVector<Value> dynamicSizes;
  for (unsigned i = 0, rank = tensorTp.getRank(); i < rank; i++) {
    if (shape[i] == ShapedType::kDynamicSize)
      dynamicSizes.push_back(sizes[i]);
  }
  Value mem = builder.create<memref::AllocOp>(loc, memTp, dynamicSizes);
  Value zero = constantZero(builder, loc, elemTp);
  builder.create<linalg::FillOp>(loc, ValueRange{zero}, ValueRange{mem});
  return mem;
}

/// Generates code to deallocate a dense buffer.
static void deallocDenseTensor(OpBuilder &builder, Location loc, Value buffer) {
  builder.create<memref::DeallocOp>(loc, buffer);
}

/// Inserts the element returned by genGetNextCall(_, ind, elemPtr) into
/// the tensor created by allocDenseTensor(). The `rank` is the rank
/// of the `tensor` and the length of `ind`.
static void insertScalarIntoDenseTensor(OpBuilder &builder, Location loc,
                                        Value elemPtr, Value tensor,
                                        unsigned rank, Value ind) {
  SmallVector<Value, 4> ivs;
  ivs.reserve(rank);
  for (unsigned i = 0; i < rank; i++) {
    Value idx = constantIndex(builder, loc, i);
    ivs.push_back(builder.create<memref::LoadOp>(loc, ind, idx));
  }
  Value elemV = builder.create<memref::LoadOp>(loc, elemPtr);
  builder.create<memref::StoreOp>(loc, elemV, tensor, ivs);
}

/// Determine if the runtime library supports direct conversion to the
/// given target `dimTypes`.
static bool canUseDirectConversion(
    ArrayRef<SparseTensorEncodingAttr::DimLevelType> dimTypes) {
  bool alreadyCompressed = false;
  for (uint64_t rank = dimTypes.size(), r = 0; r < rank; r++) {
    switch (dimTypes[r]) {
    case SparseTensorEncodingAttr::DimLevelType::Compressed:
      if (alreadyCompressed)
        return false; // Multiple compressed dimensions not yet supported.
      alreadyCompressed = true;
      break;
    case SparseTensorEncodingAttr::DimLevelType::Dense:
      if (alreadyCompressed)
        return false; // Dense after Compressed not yet supported.
      break;
    case SparseTensorEncodingAttr::DimLevelType::Singleton:
      // Although Singleton isn't generally supported yet, the direct
      // conversion method doesn't have any particular problems with
      // singleton after compressed.
      break;
    }
  }
  return true;
}

/// Helper method to translate indices during a reshaping operation.
/// TODO: provide as general utility to MLIR at large?
static void translateIndices(Location loc, ConversionPatternRewriter &rewriter,
                             ArrayRef<ReassociationIndices> reassociation,
                             TensorType dstTp, TensorType srcTp, Value dstIdx,
                             Value srcIdx) {
  unsigned dstRank = dstTp.getRank();
  unsigned srcRank = srcTp.getRank();
  unsigned start = 0;
  unsigned i = 0;
  bool isExpand = srcRank > dstRank;
  ArrayRef<int64_t> shape = isExpand ? srcTp.getShape() : dstTp.getShape();
  // Iterate over reassociation map.
  for (const auto &map : llvm::enumerate(reassociation)) {
    // Prepare strides information in dimension slice.
    uint64_t linear = 1;
    for (unsigned j = start, end = start + map.value().size(); j < end; j++) {
      assert(!ShapedType::isDynamic(shape[j]));
      linear *= shape[j];
    }
    // Start collapse.
    Value idx = constantIndex(rewriter, loc, i++);
    Value val;
    if (!isExpand)
      val = rewriter.create<memref::LoadOp>(loc, srcIdx, idx);
    // Iterate over dimension slice.
    for (unsigned j = start, end = start + map.value().size(); j < end; j++) {
      linear /= shape[j];
      Value stride = constantIndex(rewriter, loc, linear);
      Value jdx = constantIndex(rewriter, loc, j);
      if (isExpand) {
        Value old = rewriter.create<memref::LoadOp>(loc, srcIdx, jdx);
        Value mul = linear == 1
                        ? old
                        : rewriter.create<arith::MulIOp>(loc, old, stride);
        val = val ? rewriter.create<arith::AddIOp>(loc, val, mul) : mul;
      } else {
        Value old = val;
        if (linear != 1)
          val = rewriter.create<arith::DivUIOp>(loc, val, stride);
        rewriter.create<memref::StoreOp>(loc, val, dstIdx, jdx);
        if (linear != 1)
          val = rewriter.create<arith::RemUIOp>(loc, old, stride);
      }
    }
    // Finalize expansion.
    if (isExpand)
      rewriter.create<memref::StoreOp>(loc, val, dstIdx, idx);
    start += map.value().size();
  }
  // Sanity.
  assert((isExpand && i == dstRank) || (!isExpand && i == srcRank));
}
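
// For example: collapsing a tensor<2x3xf64> source into a tensor<6xf64>
// destination maps a source index (i, j) to the destination index i * 3 + j,
// whereas the expanding direction recovers (i, j) from the linear index by
// dividing and taking remainders with the same strides.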

/// Generate code for a general sparse to sparse reshaping operation.
/// Note that unlike dense reshaping (which can be done with a "cheap"
/// change of view), sparse reshaping is currently done with actual
/// data shuffling.
///
/// TODO: proportional to nnz, but still a lot of data movement
/// https://github.com/llvm/llvm-project/issues/56477
///
///   iter = src->toCOO();
///   coo = newSparseCOO()
///   while (elem = iter->getNext()) {
///     coo->add(reshape(elem.indices), elem.value)
///   }
///   s = newSparseTensor(coo)
static LogicalResult
genSparse2SparseReshape(Operation *op, ConversionPatternRewriter &rewriter,
                        ArrayRef<ReassociationIndices> reassociation, Value src,
                        RankedTensorType dstTp, RankedTensorType srcTp) {
  Location loc = op->getLoc();
  auto encDst = getSparseTensorEncoding(dstTp);
  auto encSrc = getSparseTensorEncoding(srcTp);
  assert(encDst && encSrc);
  unsigned srcRank = srcTp.getRank();
  unsigned dstRank = dstTp.getRank();
  Type elemTp = srcTp.getElementType();
  assert(elemTp == dstTp.getElementType() &&
         "reshape should not change element type");
  // Start an iterator over the source tensor (in original index order).
  auto noPerm = SparseTensorEncodingAttr::get(
      op->getContext(), encSrc.getDimLevelType(), AffineMap(),
      encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
  SmallVector<Value, 4> sizes;
  SmallVector<Value, 8> params;
  sizesFromPtr(rewriter, sizes, op, noPerm, srcTp, src);
  newParams(rewriter, params, op, srcTp, noPerm, Action::kToIterator, sizes,
            src);
  Value iter = genNewCall(rewriter, op, params);
  // Start a new COO for the destination tensor.
  sizes.clear();
  params.clear();
  sizesFromPtr(rewriter, sizes, op, encDst, dstTp, src);
  newParams(rewriter, params, op, dstTp, encDst, Action::kEmptyCOO, sizes);
  Value coo = genNewCall(rewriter, op, params);
  Value dstPerm = params[2];
  // Construct a while loop over the iterator.
  Value srcIdx = genAlloca(rewriter, loc, srcRank, rewriter.getIndexType());
  Value dstIdx = genAlloca(rewriter, loc, dstRank, rewriter.getIndexType());
  Value elemPtr = genAllocaScalar(rewriter, loc, elemTp);
  SmallVector<Value> noArgs;
  SmallVector<Type> noTypes;
  auto whileOp = rewriter.create<scf::WhileOp>(loc, noTypes, noArgs);
  Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, noTypes);
  rewriter.setInsertionPointToEnd(before);
  Value cond = genGetNextCall(rewriter, op, iter, srcIdx, elemPtr);
  rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments());
  // Translate indices from source to target and insert. Note that we do
  // not need to store the value in elemPtr, as the value is still there.
  Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, noTypes);
  rewriter.setInsertionPointToStart(after);
  translateIndices(loc, rewriter, reassociation, dstTp, srcTp, dstIdx, srcIdx);
  genAddEltCall(rewriter, op, elemTp, coo, elemPtr, dstIdx, dstPerm);
  rewriter.create<scf::YieldOp>(loc);
  // Final call to construct sparse tensor storage and free temporary resources.
  rewriter.setInsertionPointAfter(whileOp);
  params[6] = constantAction(rewriter, loc, Action::kFromCOO);
  params[7] = coo;
  Value dst = genNewCall(rewriter, op, params);
  genDelCOOCall(rewriter, op, elemTp, coo);
  genDelCOOCall(rewriter, op, elemTp, iter);
  rewriter.replaceOp(op, dst);
  return success();
}

//===----------------------------------------------------------------------===//
// Conversion rules.
//===----------------------------------------------------------------------===//
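
// As a running example of the lowering style used by the patterns below, a
// dimension query on an annotated tensor such as
//
//   %d = tensor.dim %t, %c0 : tensor<?x?xf64, #CSR>
//
// is rewritten into a call to the runtime support library, with the sparse
// tensor operand already converted to an opaque pointer:
//
//   %d = call @sparseDimSize(%t, %c0) : (!llvm.ptr<i8>, index) -> index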

/// Sparse conversion rule for returns.
class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for dimension accesses.
class SparseTensorToDimSizeConverter
    : public OpConversionPattern<tensor::DimOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only rewrite annotated DimOp with constant index.
    auto enc = getSparseTensorEncoding(op.getSource().getType());
    if (!enc)
      return failure();
    Optional<int64_t> index = op.getConstantIndex();
    if (!index)
      return failure();
    // Generate the call.
    Value src = adaptor.getOperands()[0];
    int64_t idx = *index;
    rewriter.replaceOp(op, genDimSizeCall(rewriter, op, enc, src, idx));
    return success();
  }
};

/// Sparse conversion rule for trivial tensor casts.
class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only rewrite identically annotated source/dest.
    auto encDst = getSparseTensorEncoding(op.getType());
    auto encSrc = getSparseTensorEncoding(op.getSource().getType());
    if (!encDst || encDst != encSrc)
      return failure();
    rewriter.replaceOp(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for a reshape operator.
template <typename ReshapeOp>
class SparseReshapeConverter : public OpConversionPattern<ReshapeOp> {
public:
  using OpAdaptor = typename OpConversionPattern<ReshapeOp>::OpAdaptor;
  using OpConversionPattern<ReshapeOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReshapeOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type dstType = op.getResult().getType();
    Type srcType = op.getSrc().getType();
    auto encDst = getSparseTensorEncoding(dstType);
    auto encSrc = getSparseTensorEncoding(srcType);
    if (encDst && encSrc)
      return genSparse2SparseReshape(
          op, rewriter, op.getReassociationIndices(), adaptor.getOperands()[0],
          dstType.cast<RankedTensorType>(), srcType.cast<RankedTensorType>());
    return failure(); // handled elsewhere
  }
};

/// Sparse conversion rule for the new operator.
class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NewOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    auto enc = getSparseTensorEncoding(resType);
    if (!enc)
      return failure();
    // Generate the call to construct tensor from ptr. The sizes are
    // inferred from the result type of the new operator.
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 8> params;
    ShapedType stp = resType.cast<ShapedType>();
    sizesFromType(rewriter, sizes, op.getLoc(), stp);
    Value ptr = adaptor.getOperands()[0];
    newParams(rewriter, params, op, stp, enc, Action::kFromFile, sizes, ptr);
    rewriter.replaceOp(op, genNewCall(rewriter, op, params));
    return success();
  }
};

/// Sparse conversion rule for the alloc operator.
class SparseTensorAllocConverter
    : public OpConversionPattern<bufferization::AllocTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(bufferization::AllocTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getCopy())
      return rewriter.notifyMatchFailure(op,
                                         "sparse tensor copy not implemented");
    RankedTensorType resType = op.getType();
    auto enc = getSparseTensorEncoding(resType);
    if (!enc)
      return failure();
    // Gather all dimension sizes as SSA values.
    SmallVector<Value> sizes;
    unsigned int operandCtr = 0;
    for (int64_t i = 0; i < resType.getRank(); ++i) {
      if (resType.isDynamicDim(i)) {
        sizes.push_back(adaptor.getOperands()[operandCtr++]);
      } else {
        sizes.push_back(rewriter.create<arith::ConstantIndexOp>(
            op.getLoc(), op.getStaticSize(i)));
      }
    }
    // Generate the call to construct empty tensor. The sizes are
    // explicitly defined by the arguments to the alloc operator.
    SmallVector<Value, 8> params;
    ShapedType stp = resType.cast<ShapedType>();
    newParams(rewriter, params, op, stp, enc, Action::kEmpty, sizes);
    rewriter.replaceOp(op, genNewCall(rewriter, op, params));
    return success();
  }
};

/// Sparse conversion rule for the convert operator.
class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  SparseTensorConvertConverter(MLIRContext *context,
                               SparseTensorConversionOptions o)
      : OpConversionPattern<ConvertOp>(context), options(o) {}
  SparseTensorConvertConverter(TypeConverter &typeConv, MLIRContext *context,
                               SparseTensorConversionOptions o)
      : OpConversionPattern<ConvertOp>(typeConv, context), options(o) {}

  LogicalResult
  matchAndRewrite(ConvertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    Type resType = op.getType();
    Type srcType = op.getSource().getType();
    auto encDst = getSparseTensorEncoding(resType);
    auto encSrc = getSparseTensorEncoding(srcType);
    Value src = adaptor.getOperands()[0];
    if (encDst && encSrc) {
      // This is a sparse => sparse conversion, which is handled as follows:
      //   t = src->toCOO();         ; src to COO in dst order
      //   dst = newSparseTensor(t)
      // Using the coordinate scheme as an intermediate does not always
      // yield the fastest conversion but avoids the need for a full
      // O(N^2) conversion matrix.
      if (encDst == encSrc) {
        rewriter.replaceOp(op, adaptor.getOperands()); // hidden nop cast
        return success();
      }
      SmallVector<Value, 4> sizes;
      SmallVector<Value, 8> params;
      ShapedType stp = srcType.cast<ShapedType>();
      sizesFromPtr(rewriter, sizes, op, encSrc, stp, src);
      bool useDirectConversion;
      switch (options.sparseToSparseStrategy) {
      case SparseToSparseConversionStrategy::kViaCOO:
        useDirectConversion = false;
        break;
      case SparseToSparseConversionStrategy::kDirect:
        useDirectConversion = true;
        assert(canUseDirectConversion(encDst.getDimLevelType()) &&
               "Unsupported target for direct sparse-to-sparse conversion");
        break;
      case SparseToSparseConversionStrategy::kAuto:
        useDirectConversion = canUseDirectConversion(encDst.getDimLevelType());
        break;
      }
      if (useDirectConversion) {
        newParams(rewriter, params, op, stp, encDst, Action::kSparseToSparse,
                  sizes, src);
        rewriter.replaceOp(op, genNewCall(rewriter, op, params));
      } else { // use via-COO conversion.
        // Set up encoding with right mix of src and dst so that the two
        // method calls can share most parameters, while still providing
        // the correct sparsity information to either of them.
        auto enc = SparseTensorEncodingAttr::get(
            op->getContext(), encDst.getDimLevelType(), encDst.getDimOrdering(),
            encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
        newParams(rewriter, params, op, stp, enc, Action::kToCOO, sizes, src);
        Value coo = genNewCall(rewriter, op, params);
        params[3] = constantPointerTypeEncoding(rewriter, loc, encDst);
        params[4] = constantIndexTypeEncoding(rewriter, loc, encDst);
        params[6] = constantAction(rewriter, loc, Action::kFromCOO);
        params[7] = coo;
        Value dst = genNewCall(rewriter, op, params);
        genDelCOOCall(rewriter, op, stp.getElementType(), coo);
        rewriter.replaceOp(op, dst);
      }
      return success();
    }
    if (!encDst && encSrc) {
      // This is sparse => dense conversion, which is handled as follows:
      //   dst = new Tensor(0);
      //   iter = src->toCOO();
      //   iter->startIterator();
      //   while (elem = iter->getNext()) {
      //     dst[elem.indices] = elem.value;
      //   }
      RankedTensorType dstTensorTp = resType.cast<RankedTensorType>();
      RankedTensorType srcTensorTp = srcType.cast<RankedTensorType>();
      unsigned rank = dstTensorTp.getRank();
      Type elemTp = dstTensorTp.getElementType();
      // Fabricate a no-permutation encoding for newParams().
      // The pointer/index types must be those of `src`.
      // The dimLevelTypes aren't actually used by Action::kToIterator.
      encDst = SparseTensorEncodingAttr::get(
          op->getContext(),
          SmallVector<SparseTensorEncodingAttr::DimLevelType>(
              rank, SparseTensorEncodingAttr::DimLevelType::Dense),
          AffineMap(), encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
      SmallVector<Value, 4> sizes;
      SmallVector<Value, 8> params;
      sizesFromPtr(rewriter, sizes, op, encSrc, srcTensorTp, src);
      newParams(rewriter, params, op, dstTensorTp, encDst, Action::kToIterator,
                sizes, src);
      Value iter = genNewCall(rewriter, op, params);
      Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
      Value elemPtr = genAllocaScalar(rewriter, loc, elemTp);
      Block *insertionBlock = rewriter.getInsertionBlock();
      // TODO: Dense buffers should be allocated/deallocated via the callback
      // in BufferizationOptions.
      Value dst = allocDenseTensor(rewriter, loc, dstTensorTp, sizes);
      SmallVector<Value> noArgs;
      SmallVector<Type> noTypes;
      auto whileOp = rewriter.create<scf::WhileOp>(loc, noTypes, noArgs);
      Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, noTypes);
      rewriter.setInsertionPointToEnd(before);
      Value cond = genGetNextCall(rewriter, op, iter, ind, elemPtr);
      rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments());
      Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, noTypes);
      rewriter.setInsertionPointToStart(after);
      insertScalarIntoDenseTensor(rewriter, loc, elemPtr, dst, rank, ind);
      rewriter.create<scf::YieldOp>(loc);
      rewriter.setInsertionPointAfter(whileOp);
      genDelCOOCall(rewriter, op, elemTp, iter);
      rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, resType, dst);
      // Deallocate the buffer.
      if (bufferization::allocationDoesNotEscape(op->getOpResult(0))) {
        rewriter.setInsertionPoint(insertionBlock->getTerminator());
        deallocDenseTensor(rewriter, loc, dst);
      }
      return success();
    }
    if (!encDst && !encSrc) {
      // dense => dense
      return failure();
    }
    // This is a dense => sparse conversion or a sparse constant in COO =>
    // sparse conversion, which is handled as follows:
    //   t = newSparseCOO()
    //   ...code to fill the COO tensor t...
    //   s = newSparseTensor(t)
    //
    // To fill the COO tensor from a dense tensor:
    //   for i1 in dim1
    //    ..
    //     for ik in dimk
    //       val = a[i1,..,ik]
    //       if val != 0
    //         t->add(val, [i1,..,ik], [p1,..,pk])
    //
    // To fill the COO tensor from a sparse constant in COO format:
    //   for i in range(NNZ)
    //     val = values[i]
    //     [i1,..,ik] = indices[i]
    //     t->add(val, [i1,..,ik], [p1,..,pk])
    //
    // Note that the dense tensor traversal code is actually implemented
    // using MLIR IR to avoid having to expose too much low-level
    // memref traversal details to the runtime support library.
    // Also note that the code below only generates the "new" ops and
    // the loop-nest per se; whereas the entire body of the innermost
    // loop is generated by genAddElt().
    ShapedType stp = resType.cast<ShapedType>();
    unsigned rank = stp.getRank();
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 8> params;
    sizesFromSrc(rewriter, sizes, loc, src);
    newParams(rewriter, params, op, stp, encDst, Action::kEmptyCOO, sizes);
    Value coo = genNewCall(rewriter, op, params);
    Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
    Value perm = params[2];
    SmallVector<Value> lo;
    SmallVector<Value> hi;
    SmallVector<Value> st;
    Value zero = constantIndex(rewriter, loc, 0);
    Value one = constantIndex(rewriter, loc, 1);
    auto indicesValues = genSplitSparseConstant(rewriter, loc, src);
    bool isCOOConstant = indicesValues.has_value();
    Value indices;
    Value values;
    if (isCOOConstant) {
      indices = indicesValues->first;
      values = indicesValues->second;
      lo.push_back(zero);
      hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, values, 0));
      st.push_back(one);
    } else {
      for (unsigned i = 0; i < rank; i++) {
        lo.push_back(zero);
        hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, src, i));
        st.push_back(one);
      }
    }
    Type eltType = stp.getElementType();
    Value elemPtr = genAllocaScalar(rewriter, loc, eltType);
    scf::buildLoopNest(
        rewriter, op.getLoc(), lo, hi, st, {},
        [&](OpBuilder &builder, Location loc, ValueRange ivs,
            ValueRange args) -> scf::ValueVector {
          Value val;
          if (isCOOConstant)
            val = genIndexAndValueForSparse(rewriter, loc, indices, values, ind,
                                            ivs, rank);
          else
            val = genIndexAndValueForDense(rewriter, loc, src, ind, ivs);
          builder.create<memref::StoreOp>(loc, val, elemPtr);
          genAddEltCall(rewriter, op, eltType, coo, elemPtr, ind, perm);
          return {};
        });
    // Final call to construct sparse tensor storage.
    params[6] = constantAction(rewriter, loc, Action::kFromCOO);
    params[7] = coo;
    Value dst = genNewCall(rewriter, op, params);
    genDelCOOCall(rewriter, op, eltType, coo);
    rewriter.replaceOp(op, dst);
    return success();
  }

private:
  /// Options to control sparse code generation.
  SparseTensorConversionOptions options;
};

/// Sparse conversion rule for the release operator.
class SparseTensorReleaseConverter : public OpConversionPattern<ReleaseOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReleaseOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    StringRef name = "delSparseTensor";
    TypeRange noTp;
    createFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
                   EmitCInterface::Off);
    rewriter.eraseOp(op);
    return success();
  }
};

/// Sparse conversion rule for pointer accesses.
class SparseTensorToPointersConverter
    : public OpConversionPattern<ToPointersOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToPointersOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type ptrType = resType.cast<ShapedType>().getElementType();
    SmallString<16> name{"sparsePointers", overheadTypeFunctionSuffix(ptrType)};
    replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
                          EmitCInterface::On);
    return success();
  }
};
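
// Schematically, the pointers/indices/values accessors around this point turn
// an access such as
//
//   %p = sparse_tensor.pointers %t, %c0
//        : tensor<?x?xf64, #CSR> to memref<?xindex>
//
// into a library call on the converted opaque pointer, e.g.
//
//   %p = call @sparsePointers<suffix>(%t, %c0)
//        : (!llvm.ptr<i8>, index) -> memref<?xindex>
//
// where <suffix> stands for the name fragment produced by
// overheadTypeFunctionSuffix() (or primaryTypeFunctionSuffix() for values).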

/// Sparse conversion rule for index accesses.
class SparseTensorToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToIndicesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type indType = resType.cast<ShapedType>().getElementType();
    SmallString<15> name{"sparseIndices", overheadTypeFunctionSuffix(indType)};
    replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
                          EmitCInterface::On);
    return success();
  }
};

/// Sparse conversion rule for value accesses.
class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type eltType = resType.cast<ShapedType>().getElementType();
    SmallString<15> name{"sparseValues", primaryTypeFunctionSuffix(eltType)};
    replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
                          EmitCInterface::On);
    return success();
  }
};

/// Sparse conversion rule for tensor rematerialization.
class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getHasInserts()) {
      // Finalize any pending insertions.
      StringRef name = "endInsert";
      TypeRange noTp;
      createFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
                     EmitCInterface::Off);
    }
    rewriter.replaceOp(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for inserting in lexicographic index order.
class SparseTensorLexInsertConverter : public OpConversionPattern<LexInsertOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(LexInsertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type elemTp = op.getTensor().getType().cast<ShapedType>().getElementType();
    SmallString<12> name{"lexInsert", primaryTypeFunctionSuffix(elemTp)};
    TypeRange noTp;
    replaceOpWithFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
                          EmitCInterface::On);
    return success();
  }
};

/// Sparse conversion rule for the expand operator.
class SparseTensorExpandConverter : public OpConversionPattern<ExpandOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ExpandOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    ShapedType srcType = op.getTensor().getType().cast<ShapedType>();
    Type eltType = srcType.getElementType();
    Type boolType = rewriter.getIntegerType(1);
    Type idxType = rewriter.getIndexType();
    // All initialization should be done on entry of the loop nest.
    rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp());
    // Determine the size for access expansion.
    auto enc = getSparseTensorEncoding(srcType);
    Value src = adaptor.getOperands()[0];
    Value sz = genDimSizeCall(rewriter, op, enc, src, srcType.getRank() - 1);
    // Allocate temporary buffers for values, filled-switch, and indices.
    // We do not use stack buffers for this, since the expanded size may
    // be rather large (as it envelops a single expanded dense dimension).
    Value values = genAlloc(rewriter, loc, sz, eltType);
    Value filled = genAlloc(rewriter, loc, sz, boolType);
    Value indices = genAlloc(rewriter, loc, sz, idxType);
    Value zero = constantZero(rewriter, loc, idxType);
    // Reset the values/filled-switch to all-zero/false. Note that this
    // introduces an O(N) operation into the computation, but this reset
    // operation is amortized over the innermost loops for the access
    // pattern expansion. As noted in the operation doc, we would like
    // to amortize this setup cost even between kernels.
    rewriter.create<linalg::FillOp>(
        loc, ValueRange{constantZero(rewriter, loc, eltType)},
        ValueRange{values});
    rewriter.create<linalg::FillOp>(
        loc, ValueRange{constantZero(rewriter, loc, boolType)},
        ValueRange{filled});
    // Replace expansion op with these buffers and initial index.
    assert(op.getNumResults() == 4);
    rewriter.replaceOp(op, {values, filled, indices, zero});
    return success();
  }
};

/// Sparse conversion rule for the compress operator.
class SparseTensorCompressConverter : public OpConversionPattern<CompressOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(CompressOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    // Note that this method call resets the values/filled-switch back to
    // all-zero/false by only iterating over the set elements, so the
    // complexity remains proportional to the sparsity of the expanded
    // access pattern.
    Type elemTp = op.getTensor().getType().cast<ShapedType>().getElementType();
    SmallString<12> name{"expInsert", primaryTypeFunctionSuffix(elemTp)};
    TypeRange noTp;
    replaceOpWithFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
                          EmitCInterface::On);
    // Deallocate the buffers on exit of the loop nest.
    Operation *parent = op;
    for (; isa<scf::ForOp>(parent->getParentOp()) ||
           isa<scf::WhileOp>(parent->getParentOp()) ||
           isa<scf::ParallelOp>(parent->getParentOp()) ||
           isa<scf::IfOp>(parent->getParentOp());
         parent = parent->getParentOp())
      ;
    rewriter.setInsertionPointAfter(parent);
    rewriter.create<memref::DeallocOp>(loc, adaptor.getOperands()[2]);
    rewriter.create<memref::DeallocOp>(loc, adaptor.getOperands()[3]);
    rewriter.create<memref::DeallocOp>(loc, adaptor.getOperands()[4]);
    return success();
  }
};

/// Sparse conversion rule for the output operator.
class SparseTensorOutConverter : public OpConversionPattern<OutOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(OutOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    ShapedType srcType = op.getTensor().getType().cast<ShapedType>();
    // Convert to default permuted COO.
    Value src = adaptor.getOperands()[0];
    auto encSrc = getSparseTensorEncoding(srcType);
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 8> params;
    sizesFromPtr(rewriter, sizes, op, encSrc, srcType, src);
    auto enc = SparseTensorEncodingAttr::get(
        op->getContext(), encSrc.getDimLevelType(), AffineMap(),
        encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
    newParams(rewriter, params, op, srcType, enc, Action::kToCOO, sizes, src);
    Value coo = genNewCall(rewriter, op, params);
    // Then output the tensor to external file with indices in the externally
    // visible lexicographic index order. A sort is required if the source was
    // not in that order yet (note that the sort can be dropped altogether if
    // external format does not care about the order at all, but here we assume
    // it does).
    bool sort =
        encSrc.getDimOrdering() && !encSrc.getDimOrdering().isIdentity();
    params.clear();
    params.push_back(coo);
    params.push_back(adaptor.getOperands()[1]);
    params.push_back(constantI1(rewriter, loc, sort));
    Type eltType = srcType.getElementType();
    SmallString<18> name{"outSparseTensor", primaryTypeFunctionSuffix(eltType)};
    TypeRange noTp;
    createFuncCall(rewriter, op, name, noTp, params, EmitCInterface::Off);
    genDelCOOCall(rewriter, op, eltType, coo);
    rewriter.eraseOp(op);
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with conversion rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparseTensorConversionPatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    const SparseTensorConversionOptions &options) {
  patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter,
               SparseCastConverter, SparseTensorNewConverter,
               SparseReshapeConverter<tensor::ExpandShapeOp>,
               SparseReshapeConverter<tensor::CollapseShapeOp>,
               SparseTensorAllocConverter, SparseTensorReleaseConverter,
               SparseTensorToPointersConverter, SparseTensorToIndicesConverter,
               SparseTensorToValuesConverter, SparseTensorLoadConverter,
               SparseTensorLexInsertConverter, SparseTensorExpandConverter,
               SparseTensorCompressConverter, SparseTensorOutConverter>(
      typeConverter, patterns.getContext());
  patterns.add<SparseTensorConvertConverter>(typeConverter,
                                             patterns.getContext(), options);
}