//===- SparseTensorConversion.cpp - Sparse tensor primitives conversion ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Convert sparse tensor primitives to calls into a runtime support library.
// Note that this is a current implementation choice to keep the conversion
// simple. In principle, these primitives could also be converted to actual
// elaborate IR code that implements the primitives on the selected sparse
// tensor storage schemes.
//
//===----------------------------------------------------------------------===//

#include "CodegenUtils.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/ExecutionEngine/SparseTensorUtils.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

/// Shorthand aliases for the `emitCInterface` argument to `getFunc()`,
/// `createFuncCall()`, and `replaceOpWithFuncCall()`.
enum class EmitCInterface : bool { Off = false, On = true };

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

/// Returns the equivalent of `void*` for opaque arguments to the
/// execution engine.
static Type getOpaquePointerType(OpBuilder &builder) {
  return LLVM::LLVMPointerType::get(builder.getI8Type());
}

/// Returns a function reference (first hit also inserts it into the module).
/// Sets the "_emit_c_interface" attribute on the function declaration when
/// requested, so that LLVM lowering generates a wrapper function that takes
/// care of ABI complications with passing in and returning MemRefs to C
/// functions.
static FlatSymbolRefAttr getFunc(Operation *op, StringRef name,
                                 TypeRange resultType, ValueRange operands,
                                 EmitCInterface emitCInterface) {
  MLIRContext *context = op->getContext();
  auto module = op->getParentOfType<ModuleOp>();
  auto result = SymbolRefAttr::get(context, name);
  auto func = module.lookupSymbol<func::FuncOp>(result.getAttr());
  if (!func) {
    OpBuilder moduleBuilder(module.getBodyRegion());
    func = moduleBuilder.create<func::FuncOp>(
        op->getLoc(), name,
        FunctionType::get(context, operands.getTypes(), resultType));
    func.setPrivate();
    if (static_cast<bool>(emitCInterface))
      func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
                    UnitAttr::get(context));
  }
  return result;
}

/// Creates a `CallOp` to the function reference returned by `getFunc()`.
static func::CallOp createFuncCall(OpBuilder &builder, Operation *op,
                                   StringRef name, TypeRange resultType,
                                   ValueRange operands,
                                   EmitCInterface emitCInterface) {
  auto fn = getFunc(op, name, resultType, operands, emitCInterface);
  return builder.create<func::CallOp>(op->getLoc(), resultType, fn, operands);
}

/// Replaces the `op` with a `CallOp` to the function reference returned
/// by `getFunc()`.
static func::CallOp replaceOpWithFuncCall(RewriterBase &rewriter, Operation *op,
                                          StringRef name, TypeRange resultType,
                                          ValueRange operands,
                                          EmitCInterface emitCInterface) {
  auto fn = getFunc(op, name, resultType, operands, emitCInterface);
  return rewriter.replaceOpWithNewOp<func::CallOp>(op, resultType, fn,
                                                   operands);
}

/// Generates a dimension size call.
static Value genDimSizeCall(OpBuilder &builder, Operation *op,
                            SparseTensorEncodingAttr &enc, Value src,
                            int64_t idx) {
  // Permute the index according to an optional dimension ordering.
  if (AffineMap p = enc.getDimOrdering())
    idx = p.getPermutedPosition(idx);
  // Generate the call.
  StringRef name = "sparseDimSize";
  SmallVector<Value, 2> params{src, constantIndex(builder, op->getLoc(), idx)};
  Type iTp = builder.getIndexType();
  return createFuncCall(builder, op, name, iTp, params, EmitCInterface::Off)
      .getResult(0);
}

/// Generates a call into the "swiss army knife" method of the sparse runtime
/// support library for materializing sparse tensors into the computation.
static Value genNewCall(OpBuilder &builder, Operation *op,
                        ArrayRef<Value> params) {
  StringRef name = "newSparseTensor";
  Type pTp = getOpaquePointerType(builder);
  return createFuncCall(builder, op, name, pTp, params, EmitCInterface::On)
      .getResult(0);
}

/// Populates the given sizes array from the type (dynamic sizes are set to
/// zero).
static void sizesFromType(OpBuilder &builder, SmallVector<Value, 4> &sizes,
                          Location loc, ShapedType stp) {
  auto shape = stp.getShape();
  for (unsigned i = 0, rank = stp.getRank(); i < rank; i++) {
    uint64_t s = shape[i] == ShapedType::kDynamicSize ? 0 : shape[i];
    sizes.push_back(constantIndex(builder, loc, s));
  }
}

/// Populates the given sizes array from the source tensor.
static void sizesFromSrc(OpBuilder &builder, SmallVector<Value, 4> &sizes,
                         Location loc, Value src) {
  unsigned rank = src.getType().cast<ShapedType>().getRank();
  for (unsigned i = 0; i < rank; i++)
    sizes.push_back(linalg::createOrFoldDimOp(builder, loc, src, i));
}

/// Populates the given sizes array from the type (for static sizes) and from
/// the source, which has already been converted into an opaque pointer (for
/// dynamic sizes).
static void sizesFromPtr(OpBuilder &builder, SmallVector<Value, 4> &sizes,
                         Operation *op, SparseTensorEncodingAttr &enc,
                         ShapedType stp, Value src) {
  Location loc = op->getLoc();
  auto shape = stp.getShape();
  for (unsigned i = 0, rank = stp.getRank(); i < rank; i++)
    if (shape[i] == ShapedType::kDynamicSize)
      sizes.push_back(genDimSizeCall(builder, op, enc, src, i));
    else
      sizes.push_back(constantIndex(builder, loc, shape[i]));
}

/// Generates an uninitialized temporary buffer of the given size and
/// type, but returns it as type `memref<? x $tp>` (rather than as type
/// `memref<$sz x $tp>`).
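/// The buffer is created with `memref.alloca`, so it lives on the stack and
/// is freed automatically when control leaves the enclosing allocation scope.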
static Value genAlloca(OpBuilder &builder, Location loc, Value sz, Type tp) {
  auto memTp = MemRefType::get({ShapedType::kDynamicSize}, tp);
  return builder.create<memref::AllocaOp>(loc, memTp, ValueRange{sz});
}

/// Generates an uninitialized buffer of the given size and type,
/// but returns it as type `memref<? x $tp>` (rather than as type
/// `memref<$sz x $tp>`). Unlike temporary buffers on the stack,
/// this buffer must be explicitly deallocated by the client.
static Value genAlloc(RewriterBase &rewriter, Location loc, Value sz, Type tp) {
  auto memTp = MemRefType::get({ShapedType::kDynamicSize}, tp);
  return rewriter.create<memref::AllocOp>(loc, memTp, ValueRange{sz});
}

/// Generates an uninitialized temporary buffer of the given size and
/// type, but returns it as type `memref<? x $tp>` (rather than as type
/// `memref<$sz x $tp>`).
static Value genAlloca(OpBuilder &builder, Location loc, unsigned sz, Type tp) {
  return genAlloca(builder, loc, constantIndex(builder, loc, sz), tp);
}

/// Generates an uninitialized temporary buffer with room for one value
/// of the given type, and returns the `memref<$tp>`.
static Value genAllocaScalar(OpBuilder &builder, Location loc, Type tp) {
  return builder.create<memref::AllocaOp>(loc, MemRefType::get({}, tp));
}

/// Generates a temporary buffer of the given type and given contents.
static Value genBuffer(OpBuilder &builder, Location loc, ValueRange values) {
  unsigned sz = values.size();
  assert(sz >= 1);
  Value buffer = genAlloca(builder, loc, sz, values[0].getType());
  for (unsigned i = 0; i < sz; i++) {
    Value idx = constantIndex(builder, loc, i);
    builder.create<memref::StoreOp>(loc, values[i], buffer, idx);
  }
  return buffer;
}

/// Populates the parameters required to call the "swiss army knife" method of
/// the sparse runtime support library for materializing sparse tensors into
/// the computation.
static void newParams(OpBuilder &builder, SmallVector<Value, 8> &params,
                      Operation *op, ShapedType stp,
                      SparseTensorEncodingAttr &enc, Action action,
                      ValueRange szs, Value ptr = Value()) {
  Location loc = op->getLoc();
  ArrayRef<SparseTensorEncodingAttr::DimLevelType> dlt = enc.getDimLevelType();
  unsigned sz = dlt.size();
  // Sparsity annotations.
  SmallVector<Value, 4> attrs;
  for (unsigned i = 0; i < sz; i++)
    attrs.push_back(constantDimLevelTypeEncoding(builder, loc, dlt[i]));
  params.push_back(genBuffer(builder, loc, attrs));
  // Dimension sizes array of the enveloping tensor. Useful for either
  // verification of external data, or for construction of internal data.
  params.push_back(genBuffer(builder, loc, szs));
  // Dimension order permutation array. This is the "identity" permutation by
  // default, or otherwise the "reverse" permutation of a given ordering, so
  // that indices can be mapped quickly to the right position.
  SmallVector<Value, 4> rev(sz);
  if (AffineMap p = enc.getDimOrdering()) {
    for (unsigned i = 0; i < sz; i++)
      rev[p.getDimPosition(i)] = constantIndex(builder, loc, i);
  } else {
    for (unsigned i = 0; i < sz; i++)
      rev[i] = constantIndex(builder, loc, i);
  }
  params.push_back(genBuffer(builder, loc, rev));
  // Secondary and primary types encoding.
  Type elemTp = stp.getElementType();
  params.push_back(constantPointerTypeEncoding(builder, loc, enc));
  params.push_back(constantIndexTypeEncoding(builder, loc, enc));
  params.push_back(constantPrimaryTypeEncoding(builder, loc, elemTp));
  // User action.
  params.push_back(constantAction(builder, loc, action));
  // Payload pointer.
  if (!ptr)
    ptr = builder.create<LLVM::NullOp>(loc, getOpaquePointerType(builder));
  params.push_back(ptr);
}

/// Generates the code to read the value from tensor[ivs], and conditionally
/// stores the indices ivs to the memory in ind. The generated code looks like
/// the following; the insertion point after this routine is inside the
/// if-then branch, behind the assignment to ind. This ensures that the
/// addEltX call generated afterwards ends up inside the if-then branch as
/// well.
///    if (tensor[ivs] != 0)
///      ind = ivs
static Value genIndexAndValueForDense(OpBuilder &builder, Location loc,
                                      Value tensor, Value ind, ValueRange ivs) {
  Value val = builder.create<tensor::ExtractOp>(loc, tensor, ivs);
  Value cond = genIsNonzero(builder, loc, val);
  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else*/ false);
  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
  unsigned i = 0;
  for (auto iv : ivs) {
    Value idx = constantIndex(builder, loc, i++);
    builder.create<memref::StoreOp>(loc, iv, ind, idx);
  }
  return val;
}

/// Generates a call to release/delete a `SparseTensorCOO`.
static void genDelCOOCall(OpBuilder &builder, Operation *op, Type elemTp,
                          Value coo) {
  SmallString<21> name{"delSparseTensorCOO", primaryTypeFunctionSuffix(elemTp)};
  TypeRange noTp;
  createFuncCall(builder, op, name, noTp, coo, EmitCInterface::Off);
}

/// Generates a call that adds one element to a coordinate scheme.
/// In particular, this generates code like the following:
///    val = a[i1,..,ik];
///    if val != 0
///      t->add(&val, [i1,..,ik], [p1,..,pk]);
static void genAddEltCall(OpBuilder &builder, Operation *op, Type eltType,
                          Value ptr, Value valPtr, Value ind, Value perm) {
  SmallString<9> name{"addElt", primaryTypeFunctionSuffix(eltType)};
  SmallVector<Value, 4> params{ptr, valPtr, ind, perm};
  Type pTp = getOpaquePointerType(builder);
  createFuncCall(builder, op, name, pTp, params, EmitCInterface::On);
}

/// Generates a call to `iter->getNext()`. If there is a next element,
/// then it is copied into the out-parameters `ind` and `elemPtr`,
/// and the return value is true. If there isn't a next element, then
/// the memory for `iter` is freed and the return value is false.
static Value genGetNextCall(OpBuilder &builder, Operation *op, Value iter,
                            Value ind, Value elemPtr) {
  Type elemTp = elemPtr.getType().cast<ShapedType>().getElementType();
  SmallString<10> name{"getNext", primaryTypeFunctionSuffix(elemTp)};
  SmallVector<Value, 3> params{iter, ind, elemPtr};
  Type i1 = builder.getI1Type();
  return createFuncCall(builder, op, name, i1, params, EmitCInterface::On)
      .getResult(0);
}

/// If the tensor is a sparse constant, generates and returns the pair of
/// the constants for the indices and the values.
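/// Returns an empty `Optional` otherwise.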
static Optional<std::pair<Value, Value>>
genSplitSparseConstant(OpBuilder &builder, Location loc, Value tensor) {
  if (auto constOp = tensor.getDefiningOp<arith::ConstantOp>()) {
    if (auto attr = constOp.getValue().dyn_cast<SparseElementsAttr>()) {
      DenseElementsAttr indicesAttr = attr.getIndices();
      Value indices = builder.create<arith::ConstantOp>(loc, indicesAttr);
      DenseElementsAttr valuesAttr = attr.getValues();
      Value values = builder.create<arith::ConstantOp>(loc, valuesAttr);
      return std::make_pair(indices, values);
    }
  }
  return {};
}

/// Generates the code to copy the index at indices[ivs] to ind, and returns
/// the value at values[ivs].
static Value genIndexAndValueForSparse(OpBuilder &builder, Location loc,
                                       Value indices, Value values, Value ind,
                                       ValueRange ivs, unsigned rank) {
  for (unsigned i = 0; i < rank; i++) {
    Value idx = constantIndex(builder, loc, i);
    Value val = builder.create<tensor::ExtractOp>(loc, indices,
                                                  ValueRange{ivs[0], idx});
    val = builder.create<arith::IndexCastOp>(loc, builder.getIndexType(), val);
    builder.create<memref::StoreOp>(loc, val, ind, idx);
  }
  return builder.create<tensor::ExtractOp>(loc, values, ivs[0]);
}

/// Generates code to allocate a buffer of the given type, and zero
/// initialize it. If the buffer type has any dynamic sizes, then the
/// `sizes` parameter should be as filled by sizesFromPtr(); that way
/// we can reuse the genDimSizeCall() results generated by sizesFromPtr().
static Value allocDenseTensor(OpBuilder &builder, Location loc,
                              RankedTensorType tensorTp, ValueRange sizes) {
  Type elemTp = tensorTp.getElementType();
  auto shape = tensorTp.getShape();
  auto memTp = MemRefType::get(shape, elemTp);
  SmallVector<Value> dynamicSizes;
  for (unsigned i = 0, rank = tensorTp.getRank(); i < rank; i++) {
    if (shape[i] == ShapedType::kDynamicSize)
      dynamicSizes.push_back(sizes[i]);
  }
  Value mem = builder.create<memref::AllocOp>(loc, memTp, dynamicSizes);
  Value zero = constantZero(builder, loc, elemTp);
  builder.create<linalg::FillOp>(loc, ValueRange{zero}, ValueRange{mem});
  return mem;
}

/// Generates code to deallocate a dense buffer.
static void deallocDenseTensor(OpBuilder &builder, Location loc, Value buffer) {
  builder.create<memref::DeallocOp>(loc, buffer);
}

/// Inserts the element returned by genGetNextCall(_, ind, elemPtr) into
/// the tensor created by allocDenseTensor(). The `rank` is the rank
/// of the `tensor` and the length of `ind`.
static void insertScalarIntoDenseTensor(OpBuilder &builder, Location loc,
                                        Value elemPtr, Value tensor,
                                        unsigned rank, Value ind) {
  SmallVector<Value, 4> ivs;
  ivs.reserve(rank);
  for (unsigned i = 0; i < rank; i++) {
    Value idx = constantIndex(builder, loc, i);
    ivs.push_back(builder.create<memref::LoadOp>(loc, ind, idx));
  }
  Value elemV = builder.create<memref::LoadOp>(loc, elemPtr);
  builder.create<memref::StoreOp>(loc, elemV, tensor, ivs);
}

/// Determine if the runtime library supports direct conversion to the
/// given target `dimTypes`.
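/// Currently, this means at most one compressed dimension, with no dense
/// dimension appearing after the compressed one.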
static bool canUseDirectConversion(
    ArrayRef<SparseTensorEncodingAttr::DimLevelType> dimTypes) {
  bool alreadyCompressed = false;
  for (uint64_t rank = dimTypes.size(), r = 0; r < rank; r++) {
    switch (dimTypes[r]) {
    case SparseTensorEncodingAttr::DimLevelType::Compressed:
      if (alreadyCompressed)
        return false; // Multiple compressed dimensions not yet supported.
      alreadyCompressed = true;
      break;
    case SparseTensorEncodingAttr::DimLevelType::Dense:
      if (alreadyCompressed)
        return false; // Dense after Compressed not yet supported.
      break;
    case SparseTensorEncodingAttr::DimLevelType::Singleton:
      // Although Singleton isn't generally supported yet, the direct
      // conversion method doesn't have any particular problems with
      // singleton after compressed.
      break;
    }
  }
  return true;
}

/// Helper method to translate indices during a reshaping operation.
/// TODO: provide as general utility to MLIR at large?
static void translateIndices(Location loc, ConversionPatternRewriter &rewriter,
                             ArrayRef<ReassociationIndices> reassociation,
                             TensorType dstTp, TensorType srcTp, Value dstIdx,
                             Value srcIdx) {
  unsigned dstRank = dstTp.getRank();
  unsigned srcRank = srcTp.getRank();
  unsigned start = 0;
  unsigned i = 0;
  bool isExpand = srcRank > dstRank;
  ArrayRef<int64_t> shape = isExpand ? srcTp.getShape() : dstTp.getShape();
  // Iterate over reassociation map.
  for (const auto &map : llvm::enumerate(reassociation)) {
    // Prepare strides information in dimension slice.
    uint64_t linear = 1;
    for (unsigned j = start, end = start + map.value().size(); j < end; j++) {
      assert(!ShapedType::isDynamic(shape[j]));
      linear *= shape[j];
    }
    // Start collapse.
    Value idx = constantIndex(rewriter, loc, i++);
    Value val;
    if (!isExpand)
      val = rewriter.create<memref::LoadOp>(loc, srcIdx, idx);
    // Iterate over dimension slice.
    for (unsigned j = start, end = start + map.value().size(); j < end; j++) {
      linear /= shape[j];
      Value stride = constantIndex(rewriter, loc, linear);
      Value jdx = constantIndex(rewriter, loc, j);
      if (isExpand) {
        Value old = rewriter.create<memref::LoadOp>(loc, srcIdx, jdx);
        Value mul = linear == 1
                        ? old
                        : rewriter.create<arith::MulIOp>(loc, old, stride);
        val = val ? rewriter.create<arith::AddIOp>(loc, val, mul) : mul;
      } else {
        Value old = val;
        if (linear != 1)
          val = rewriter.create<arith::DivUIOp>(loc, val, stride);
        rewriter.create<memref::StoreOp>(loc, val, dstIdx, jdx);
        if (linear != 1)
          val = rewriter.create<arith::RemUIOp>(loc, old, stride);
      }
    }
    // Finalize expansion.
    if (isExpand)
      rewriter.create<memref::StoreOp>(loc, val, dstIdx, idx);
    start += map.value().size();
  }
  // Sanity.
  assert((isExpand && i == dstRank) || (!isExpand && i == srcRank));
}

/// Generate code for a general sparse to sparse reshaping operation.
/// Note that unlike dense reshaping (which can be done with a "cheap"
/// change of view), sparse reshaping is currently done with actual
/// data shuffling.
///
/// TODO: proportional to nnz, but still a lot of data movement
/// https://github.com/llvm/llvm-project/issues/56477
///
///   iter = src->toCOO();
///   coo = newSparseCOO()
///   while (elem = iter->getNext()) {
///     coo->add(reshape(elem.indices), elem.value)
///   }
///   s = newSparseTensor(coo)
static LogicalResult
genSparse2SparseReshape(Operation *op, ConversionPatternRewriter &rewriter,
                        ArrayRef<ReassociationIndices> reassociation, Value src,
                        RankedTensorType dstTp, RankedTensorType srcTp) {
  Location loc = op->getLoc();
  auto encDst = getSparseTensorEncoding(dstTp);
  auto encSrc = getSparseTensorEncoding(srcTp);
  assert(encDst && encSrc);
  unsigned srcRank = srcTp.getRank();
  unsigned dstRank = dstTp.getRank();
  Type elemTp = srcTp.getElementType();
  assert(elemTp == dstTp.getElementType() &&
         "reshape should not change element type");
  // Start an iterator over the source tensor (in original index order).
  auto noPerm = SparseTensorEncodingAttr::get(
      op->getContext(), encSrc.getDimLevelType(), AffineMap(),
      encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
  SmallVector<Value, 4> sizes;
  SmallVector<Value, 8> params;
  sizesFromPtr(rewriter, sizes, op, noPerm, srcTp, src);
  newParams(rewriter, params, op, srcTp, noPerm, Action::kToIterator, sizes,
            src);
  Value iter = genNewCall(rewriter, op, params);
  // Start a new COO for the destination tensor.
  sizes.clear();
  params.clear();
  sizesFromPtr(rewriter, sizes, op, encDst, dstTp, src);
  newParams(rewriter, params, op, dstTp, encDst, Action::kEmptyCOO, sizes);
  Value coo = genNewCall(rewriter, op, params);
  Value dstPerm = params[2];
  // Construct a while loop over the iterator.
  Value srcIdx = genAlloca(rewriter, loc, srcRank, rewriter.getIndexType());
  Value dstIdx = genAlloca(rewriter, loc, dstRank, rewriter.getIndexType());
  Value elemPtr = genAllocaScalar(rewriter, loc, elemTp);
  SmallVector<Value> noArgs;
  SmallVector<Type> noTypes;
  auto whileOp = rewriter.create<scf::WhileOp>(loc, noTypes, noArgs);
  Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, noTypes);
  rewriter.setInsertionPointToEnd(before);
  Value cond = genGetNextCall(rewriter, op, iter, srcIdx, elemPtr);
  rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments());
  // Translate indices from source to target and insert. Note that we do
  // not need to store the value in elemPtr, as the value is still there.
  Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, noTypes);
  rewriter.setInsertionPointToStart(after);
  translateIndices(loc, rewriter, reassociation, dstTp, srcTp, dstIdx, srcIdx);
  genAddEltCall(rewriter, op, elemTp, coo, elemPtr, dstIdx, dstPerm);
  rewriter.create<scf::YieldOp>(loc);
  // Final call to construct sparse tensor storage and free temporary
  // resources.
  rewriter.setInsertionPointAfter(whileOp);
  params[6] = constantAction(rewriter, loc, Action::kFromCOO);
  params[7] = coo;
  Value dst = genNewCall(rewriter, op, params);
  genDelCOOCall(rewriter, op, elemTp, coo);
  genDelCOOCall(rewriter, op, elemTp, iter);
  rewriter.replaceOp(op, dst);
  return success();
}

//===----------------------------------------------------------------------===//
// Conversion rules.
//===----------------------------------------------------------------------===//

/// Sparse conversion rule for returns.
class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for dimension accesses.
class SparseTensorToDimSizeConverter
    : public OpConversionPattern<tensor::DimOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only rewrite annotated DimOp with constant index.
    auto enc = getSparseTensorEncoding(op.getSource().getType());
    if (!enc)
      return failure();
    Optional<int64_t> index = op.getConstantIndex();
    if (!index)
      return failure();
    // Generate the call.
    Value src = adaptor.getOperands()[0];
    int64_t idx = *index;
    rewriter.replaceOp(op, genDimSizeCall(rewriter, op, enc, src, idx));
    return success();
  }
};

/// Sparse conversion rule for trivial tensor casts.
class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only rewrite identically annotated source/dest.
    auto encDst = getSparseTensorEncoding(op.getType());
    auto encSrc = getSparseTensorEncoding(op.getSource().getType());
    if (!encDst || encDst != encSrc)
      return failure();
    rewriter.replaceOp(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for a reshape operator.
template <typename ReshapeOp>
class SparseReshapeConverter : public OpConversionPattern<ReshapeOp> {
public:
  using OpAdaptor = typename OpConversionPattern<ReshapeOp>::OpAdaptor;
  using OpConversionPattern<ReshapeOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReshapeOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type dstType = op.getResult().getType();
    Type srcType = op.getSrc().getType();
    auto encDst = getSparseTensorEncoding(dstType);
    auto encSrc = getSparseTensorEncoding(srcType);
    if (encDst && encSrc)
      return genSparse2SparseReshape(
          op, rewriter, op.getReassociationIndices(), adaptor.getOperands()[0],
          dstType.cast<RankedTensorType>(), srcType.cast<RankedTensorType>());
    return failure(); // handled elsewhere
  }
};

/// Sparse conversion rule for the new operator.
class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NewOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    auto enc = getSparseTensorEncoding(resType);
    if (!enc)
      return failure();
    // Generate the call to construct tensor from ptr. The sizes are
    // inferred from the result type of the new operator.
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 8> params;
    ShapedType stp = resType.cast<ShapedType>();
    sizesFromType(rewriter, sizes, op.getLoc(), stp);
    Value ptr = adaptor.getOperands()[0];
    newParams(rewriter, params, op, stp, enc, Action::kFromFile, sizes, ptr);
    rewriter.replaceOp(op, genNewCall(rewriter, op, params));
    return success();
  }
};

/// Sparse conversion rule for the alloc operator.
class SparseTensorAllocConverter
    : public OpConversionPattern<bufferization::AllocTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(bufferization::AllocTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getCopy())
      return rewriter.notifyMatchFailure(op,
                                         "sparse tensor copy not implemented");
    RankedTensorType resType = op.getType();
    auto enc = getSparseTensorEncoding(resType);
    if (!enc)
      return failure();
    // Gather all dimension sizes as SSA values.
    SmallVector<Value> sizes;
    unsigned int operandCtr = 0;
    for (int64_t i = 0; i < resType.getRank(); ++i) {
      if (resType.isDynamicDim(i)) {
        sizes.push_back(adaptor.getOperands()[operandCtr++]);
      } else {
        sizes.push_back(rewriter.create<arith::ConstantIndexOp>(
            op.getLoc(), op.getStaticSize(i)));
      }
    }
    // Generate the call to construct empty tensor. The sizes are
    // explicitly defined by the arguments to the alloc operator.
    SmallVector<Value, 8> params;
    ShapedType stp = resType.cast<ShapedType>();
    newParams(rewriter, params, op, stp, enc, Action::kEmpty, sizes);
    rewriter.replaceOp(op, genNewCall(rewriter, op, params));
    return success();
  }
};

/// Sparse conversion rule for the convert operator.
class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  SparseTensorConvertConverter(MLIRContext *context,
                               SparseTensorConversionOptions o)
      : OpConversionPattern<ConvertOp>(context), options(o) {}
  SparseTensorConvertConverter(TypeConverter &typeConv, MLIRContext *context,
                               SparseTensorConversionOptions o)
      : OpConversionPattern<ConvertOp>(typeConv, context), options(o) {}

  LogicalResult
  matchAndRewrite(ConvertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    Type resType = op.getType();
    Type srcType = op.getSource().getType();
    auto encDst = getSparseTensorEncoding(resType);
    auto encSrc = getSparseTensorEncoding(srcType);
    Value src = adaptor.getOperands()[0];
    if (encDst && encSrc) {
      // This is a sparse => sparse conversion, which is handled as follows:
      //   t = src->toCOO();         ; src to COO in dst order
      //   dst = newSparseTensor(t)
      // Using the coordinate scheme as an intermediate does not always
      // yield the fastest conversion but avoids the need for a full
      // O(N^2) conversion matrix.
      if (encDst == encSrc) {
        rewriter.replaceOp(op, adaptor.getOperands()); // hidden nop cast
        return success();
      }
      SmallVector<Value, 4> sizes;
      SmallVector<Value, 8> params;
      ShapedType stp = srcType.cast<ShapedType>();
      sizesFromPtr(rewriter, sizes, op, encSrc, stp, src);
      bool useDirectConversion;
      switch (options.sparseToSparseStrategy) {
      case SparseToSparseConversionStrategy::kViaCOO:
        useDirectConversion = false;
        break;
      case SparseToSparseConversionStrategy::kDirect:
        useDirectConversion = true;
        assert(canUseDirectConversion(encDst.getDimLevelType()) &&
               "Unsupported target for direct sparse-to-sparse conversion");
        break;
      case SparseToSparseConversionStrategy::kAuto:
        useDirectConversion = canUseDirectConversion(encDst.getDimLevelType());
        break;
      }
      if (useDirectConversion) {
        newParams(rewriter, params, op, stp, encDst, Action::kSparseToSparse,
                  sizes, src);
        rewriter.replaceOp(op, genNewCall(rewriter, op, params));
      } else { // use via-COO conversion.
        // Set up encoding with right mix of src and dst so that the two
        // method calls can share most parameters, while still providing
        // the correct sparsity information to either of them.
        auto enc = SparseTensorEncodingAttr::get(
            op->getContext(), encDst.getDimLevelType(), encDst.getDimOrdering(),
            encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
        newParams(rewriter, params, op, stp, enc, Action::kToCOO, sizes, src);
        Value coo = genNewCall(rewriter, op, params);
        params[3] = constantPointerTypeEncoding(rewriter, loc, encDst);
        params[4] = constantIndexTypeEncoding(rewriter, loc, encDst);
        params[6] = constantAction(rewriter, loc, Action::kFromCOO);
        params[7] = coo;
        Value dst = genNewCall(rewriter, op, params);
        genDelCOOCall(rewriter, op, stp.getElementType(), coo);
        rewriter.replaceOp(op, dst);
      }
      return success();
    }
    if (!encDst && encSrc) {
      // This is a sparse => dense conversion, which is handled as follows:
      //   dst = new Tensor(0);
      //   iter = src->toCOO();
      //   iter->startIterator();
      //   while (elem = iter->getNext()) {
      //     dst[elem.indices] = elem.value;
      //   }
      RankedTensorType dstTensorTp = resType.cast<RankedTensorType>();
      RankedTensorType srcTensorTp = srcType.cast<RankedTensorType>();
      unsigned rank = dstTensorTp.getRank();
      Type elemTp = dstTensorTp.getElementType();
      // Fabricate a no-permutation encoding for newParams().
      // The pointer/index types must be those of `src`.
      // The dimLevelTypes aren't actually used by Action::kToIterator.
      encDst = SparseTensorEncodingAttr::get(
          op->getContext(),
          SmallVector<SparseTensorEncodingAttr::DimLevelType>(
              rank, SparseTensorEncodingAttr::DimLevelType::Dense),
          AffineMap(), encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
      SmallVector<Value, 4> sizes;
      SmallVector<Value, 8> params;
      sizesFromPtr(rewriter, sizes, op, encSrc, srcTensorTp, src);
      newParams(rewriter, params, op, dstTensorTp, encDst, Action::kToIterator,
                sizes, src);
      Value iter = genNewCall(rewriter, op, params);
      Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
      Value elemPtr = genAllocaScalar(rewriter, loc, elemTp);
      Block *insertionBlock = rewriter.getInsertionBlock();
      // TODO: Dense buffers should be allocated/deallocated via the callback
      // in BufferizationOptions.
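      // Allocate a zero-initialized dense buffer that the while loop below
      // fills with the elements produced by the iterator.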
      Value dst = allocDenseTensor(rewriter, loc, dstTensorTp, sizes);
      SmallVector<Value> noArgs;
      SmallVector<Type> noTypes;
      auto whileOp = rewriter.create<scf::WhileOp>(loc, noTypes, noArgs);
      Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, noTypes);
      rewriter.setInsertionPointToEnd(before);
      Value cond = genGetNextCall(rewriter, op, iter, ind, elemPtr);
      rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments());
      Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, noTypes);
      rewriter.setInsertionPointToStart(after);
      insertScalarIntoDenseTensor(rewriter, loc, elemPtr, dst, rank, ind);
      rewriter.create<scf::YieldOp>(loc);
      rewriter.setInsertionPointAfter(whileOp);
      genDelCOOCall(rewriter, op, elemTp, iter);
      rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, resType, dst);
      // Deallocate the buffer.
      if (bufferization::allocationDoesNotEscape(op->getOpResult(0))) {
        rewriter.setInsertionPoint(insertionBlock->getTerminator());
        deallocDenseTensor(rewriter, loc, dst);
      }
      return success();
    }
    if (!encDst && !encSrc) {
      // dense => dense
      return failure();
    }
    // This is a dense => sparse conversion or a sparse constant in COO =>
    // sparse conversion, which is handled as follows:
    //   t = newSparseCOO()
    //   ...code to fill the COO tensor t...
    //   s = newSparseTensor(t)
    //
    // To fill the COO tensor from a dense tensor:
    //   for i1 in dim1
    //     ..
    //       for ik in dimk
    //         val = a[i1,..,ik]
    //         if val != 0
    //           t->add(val, [i1,..,ik], [p1,..,pk])
    //
    // To fill the COO tensor from a sparse constant in COO format:
    //   for i in range(NNZ)
    //     val = values[i]
    //     [i1,..,ik] = indices[i]
    //     t->add(val, [i1,..,ik], [p1,..,pk])
    //
    // Note that the dense tensor traversal code is actually implemented
    // using MLIR IR to avoid having to expose too many low-level
    // memref traversal details to the runtime support library.
    // Also note that the code below only generates the "new" ops and
    // the loop-nest per se; whereas the entire body of the innermost
    // loop is generated by genAddElt().
    ShapedType stp = resType.cast<ShapedType>();
    unsigned rank = stp.getRank();
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 8> params;
    sizesFromSrc(rewriter, sizes, loc, src);
    newParams(rewriter, params, op, stp, encDst, Action::kEmptyCOO, sizes);
    Value coo = genNewCall(rewriter, op, params);
    Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
    Value perm = params[2];
    SmallVector<Value> lo;
    SmallVector<Value> hi;
    SmallVector<Value> st;
    Value zero = constantIndex(rewriter, loc, 0);
    Value one = constantIndex(rewriter, loc, 1);
    auto indicesValues = genSplitSparseConstant(rewriter, loc, src);
    bool isCOOConstant = indicesValues.has_value();
    Value indices;
    Value values;
    if (isCOOConstant) {
      indices = indicesValues->first;
      values = indicesValues->second;
      lo.push_back(zero);
      hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, values, 0));
      st.push_back(one);
    } else {
      for (unsigned i = 0; i < rank; i++) {
        lo.push_back(zero);
        hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, src, i));
        st.push_back(one);
      }
    }
    Type eltType = stp.getElementType();
    Value elemPtr = genAllocaScalar(rewriter, loc, eltType);
    scf::buildLoopNest(
        rewriter, op.getLoc(), lo, hi, st, {},
        [&](OpBuilder &builder, Location loc, ValueRange ivs,
            ValueRange args) -> scf::ValueVector {
          Value val;
          if (isCOOConstant)
            val = genIndexAndValueForSparse(rewriter, loc, indices, values, ind,
                                            ivs, rank);
          else
            val = genIndexAndValueForDense(rewriter, loc, src, ind, ivs);
          builder.create<memref::StoreOp>(loc, val, elemPtr);
          genAddEltCall(rewriter, op, eltType, coo, elemPtr, ind, perm);
          return {};
        });
    // Final call to construct sparse tensor storage.
    params[6] = constantAction(rewriter, loc, Action::kFromCOO);
    params[7] = coo;
    Value dst = genNewCall(rewriter, op, params);
    genDelCOOCall(rewriter, op, eltType, coo);
    rewriter.replaceOp(op, dst);
    return success();
  }

private:
  /// Options to control sparse code generation.
  SparseTensorConversionOptions options;
};

/// Sparse conversion rule for the dealloc operator.
class SparseTensorDeallocConverter
    : public OpConversionPattern<bufferization::DeallocTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(bufferization::DeallocTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto enc = getSparseTensorEncoding(op.getTensor().getType());
    if (!enc)
      return failure();
    StringRef name = "delSparseTensor";
    TypeRange noTp;
    createFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
                   EmitCInterface::Off);
    rewriter.eraseOp(op);
    return success();
  }
};

/// Sparse conversion rule for pointer accesses.
class SparseTensorToPointersConverter
    : public OpConversionPattern<ToPointersOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToPointersOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type ptrType = resType.cast<ShapedType>().getElementType();
    SmallString<16> name{"sparsePointers", overheadTypeFunctionSuffix(ptrType)};
    replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
                          EmitCInterface::On);
    return success();
  }
};

/// Sparse conversion rule for index accesses.
class SparseTensorToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToIndicesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type indType = resType.cast<ShapedType>().getElementType();
    SmallString<15> name{"sparseIndices", overheadTypeFunctionSuffix(indType)};
    replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
                          EmitCInterface::On);
    return success();
  }
};

/// Sparse conversion rule for value accesses.
class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type eltType = resType.cast<ShapedType>().getElementType();
    SmallString<15> name{"sparseValues", primaryTypeFunctionSuffix(eltType)};
    replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
                          EmitCInterface::On);
    return success();
  }
};

/// Sparse conversion rule for tensor rematerialization.
class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getHasInserts()) {
      // Finalize any pending insertions.
      StringRef name = "endInsert";
      TypeRange noTp;
      createFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
                     EmitCInterface::Off);
    }
    rewriter.replaceOp(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for inserting in lexicographic index order.
class SparseTensorLexInsertConverter : public OpConversionPattern<LexInsertOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(LexInsertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type elemTp = op.getTensor().getType().cast<ShapedType>().getElementType();
    SmallString<12> name{"lexInsert", primaryTypeFunctionSuffix(elemTp)};
    TypeRange noTp;
    replaceOpWithFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
                          EmitCInterface::On);
    return success();
  }
};

/// Sparse conversion rule for the expand operator.
class SparseTensorExpandConverter : public OpConversionPattern<ExpandOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ExpandOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    ShapedType srcType = op.getTensor().getType().cast<ShapedType>();
    Type eltType = srcType.getElementType();
    Type boolType = rewriter.getIntegerType(1);
    Type idxType = rewriter.getIndexType();
    // All initialization should be done on entry of the loop nest.
    rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp());
    // Determine the size for access expansion.
    auto enc = getSparseTensorEncoding(srcType);
    Value src = adaptor.getOperands()[0];
    Value sz = genDimSizeCall(rewriter, op, enc, src, srcType.getRank() - 1);
    // Allocate temporary buffers for values, filled-switch, and indices.
    // We do not use stack buffers for this, since the expanded size may
    // be rather large (as it envelops a single expanded dense dimension).
    Value values = genAlloc(rewriter, loc, sz, eltType);
    Value filled = genAlloc(rewriter, loc, sz, boolType);
    Value indices = genAlloc(rewriter, loc, sz, idxType);
    Value zero = constantZero(rewriter, loc, idxType);
    // Reset the values/filled-switch to all-zero/false. Note that this
    // introduces an O(N) operation into the computation, but this reset
    // operation is amortized over the innermost loops for the access
    // pattern expansion. As noted in the operation doc, we would like
    // to amortize this setup cost even between kernels.
    rewriter.create<linalg::FillOp>(
        loc, ValueRange{constantZero(rewriter, loc, eltType)},
        ValueRange{values});
    rewriter.create<linalg::FillOp>(
        loc, ValueRange{constantZero(rewriter, loc, boolType)},
        ValueRange{filled});
    // Replace expansion op with these buffers and initial index.
    assert(op.getNumResults() == 4);
    rewriter.replaceOp(op, {values, filled, indices, zero});
    return success();
  }
};

/// Sparse conversion rule for the compress operator.
class SparseTensorCompressConverter : public OpConversionPattern<CompressOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(CompressOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    // Note that this method call resets the values/filled-switch back to
    // all-zero/false by only iterating over the set elements, so the
    // complexity remains proportional to the sparsity of the expanded
    // access pattern.
    Type elemTp = op.getTensor().getType().cast<ShapedType>().getElementType();
    SmallString<12> name{"expInsert", primaryTypeFunctionSuffix(elemTp)};
    TypeRange noTp;
    replaceOpWithFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
                          EmitCInterface::On);
    // Deallocate the buffers on exit of the loop nest.
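    // Walk out of all enclosing scf.for/while/parallel/if constructs to find
    // the position right after the outermost loop nest, and free the buffers
    // there.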
    Operation *parent = op;
    for (; isa<scf::ForOp>(parent->getParentOp()) ||
           isa<scf::WhileOp>(parent->getParentOp()) ||
           isa<scf::ParallelOp>(parent->getParentOp()) ||
           isa<scf::IfOp>(parent->getParentOp());
         parent = parent->getParentOp())
      ;
    rewriter.setInsertionPointAfter(parent);
    rewriter.create<memref::DeallocOp>(loc, adaptor.getOperands()[2]);
    rewriter.create<memref::DeallocOp>(loc, adaptor.getOperands()[3]);
    rewriter.create<memref::DeallocOp>(loc, adaptor.getOperands()[4]);
    return success();
  }
};

/// Sparse conversion rule for the output operator.
class SparseTensorOutConverter : public OpConversionPattern<OutOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(OutOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    ShapedType srcType = op.getTensor().getType().cast<ShapedType>();
    // Convert to default permuted COO.
    Value src = adaptor.getOperands()[0];
    auto encSrc = getSparseTensorEncoding(srcType);
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 8> params;
    sizesFromPtr(rewriter, sizes, op, encSrc, srcType, src);
    auto enc = SparseTensorEncodingAttr::get(
        op->getContext(), encSrc.getDimLevelType(), AffineMap(),
        encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
    newParams(rewriter, params, op, srcType, enc, Action::kToCOO, sizes, src);
    Value coo = genNewCall(rewriter, op, params);
    // Then output the tensor to external file with indices in the externally
    // visible lexicographic index order. A sort is required if the source was
    // not in that order yet (note that the sort can be dropped altogether if
    // the external format does not care about the order at all, but here we
    // assume it does).
    bool sort =
        encSrc.getDimOrdering() && !encSrc.getDimOrdering().isIdentity();
    params.clear();
    params.push_back(coo);
    params.push_back(adaptor.getOperands()[1]);
    params.push_back(constantI1(rewriter, loc, sort));
    Type eltType = srcType.getElementType();
    SmallString<18> name{"outSparseTensor", primaryTypeFunctionSuffix(eltType)};
    TypeRange noTp;
    createFuncCall(rewriter, op, name, noTp, params, EmitCInterface::Off);
    genDelCOOCall(rewriter, op, eltType, coo);
    rewriter.eraseOp(op);
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with conversion rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparseTensorConversionPatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    const SparseTensorConversionOptions &options) {
  patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter,
               SparseCastConverter, SparseTensorNewConverter,
               SparseReshapeConverter<tensor::ExpandShapeOp>,
               SparseReshapeConverter<tensor::CollapseShapeOp>,
               SparseTensorAllocConverter, SparseTensorDeallocConverter,
               SparseTensorToPointersConverter, SparseTensorToIndicesConverter,
               SparseTensorToValuesConverter, SparseTensorLoadConverter,
               SparseTensorLexInsertConverter, SparseTensorExpandConverter,
               SparseTensorCompressConverter, SparseTensorOutConverter>(
      typeConverter, patterns.getContext());
  patterns.add<SparseTensorConvertConverter>(typeConverter,
                                             patterns.getContext(), options);
}
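// Usage sketch (illustrative only; the actual pass wiring lives elsewhere):
// the patterns above are meant to be driven by the dialect conversion
// framework, roughly as follows, where `typeConverter` is assumed to map each
// sparse tensor type to the opaque runtime pointer type and `target` marks
// sparse tensor ops illegal:
//
//   RewritePatternSet patterns(ctx);
//   SparseTensorConversionOptions options;
//   populateSparseTensorConversionPatterns(typeConverter, patterns, options);
//   if (failed(applyPartialConversion(getOperation(), target,
//                                     std::move(patterns))))
//     signalPassFailure();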