//===- SparseTensorConversion.cpp - Sparse tensor primitives conversion ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Convert sparse tensor primitives to calls into a runtime support library.
// Note that this is a current implementation choice to keep the conversion
// simple. In principle, these primitives could also be converted to actual
// elaborate IR code that implements the primitives on the selected sparse
// tensor storage schemes.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/ExecutionEngine/SparseTensorUtils.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

/// Shorthand aliases for the `emitCInterface` argument to `getFunc()`,
/// `createFuncCall()`, and `replaceOpWithFuncCall()`.
enum class EmitCInterface : bool { Off = false, On = true };

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

/// Generates a constant zero of the given type.
inline static Value constantZero(ConversionPatternRewriter &rewriter,
                                 Location loc, Type t) {
  return rewriter.create<arith::ConstantOp>(loc, t, rewriter.getZeroAttr(t));
}

/// Generates a constant of `index` type.
inline static Value constantIndex(ConversionPatternRewriter &rewriter,
                                  Location loc, int64_t i) {
  return rewriter.create<arith::ConstantIndexOp>(loc, i);
}

/// Generates a constant of `i32` type.
inline static Value constantI32(ConversionPatternRewriter &rewriter,
                                Location loc, int32_t i) {
  return rewriter.create<arith::ConstantIntOp>(loc, i, 32);
}

/// Generates a constant of `i8` type.
inline static Value constantI8(ConversionPatternRewriter &rewriter,
                               Location loc, int8_t i) {
  return rewriter.create<arith::ConstantIntOp>(loc, i, 8);
}

/// Generates a constant of the given `Action`.
static Value constantAction(ConversionPatternRewriter &rewriter, Location loc,
                            Action action) {
  return constantI32(rewriter, loc, static_cast<uint32_t>(action));
}

/// Generates a constant of the internal type encoding for overhead storage.
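/// Widths other than 8, 16, and 32 (including the default width 0 used
/// by the encoding attribute) all map to the 64-bit encoding, mirroring
/// the `default` case of the switch below.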
static Value constantOverheadTypeEncoding(ConversionPatternRewriter &rewriter,
                                          Location loc, unsigned width) {
  OverheadType sec;
  switch (width) {
  default:
    sec = OverheadType::kU64;
    break;
  case 32:
    sec = OverheadType::kU32;
    break;
  case 16:
    sec = OverheadType::kU16;
    break;
  case 8:
    sec = OverheadType::kU8;
    break;
  }
  return constantI32(rewriter, loc, static_cast<uint32_t>(sec));
}

/// Generates a constant of the internal type encoding for pointer
/// overhead storage.
static Value constantPointerTypeEncoding(ConversionPatternRewriter &rewriter,
                                         Location loc,
                                         SparseTensorEncodingAttr &enc) {
  return constantOverheadTypeEncoding(rewriter, loc, enc.getPointerBitWidth());
}

/// Generates a constant of the internal type encoding for index overhead
/// storage.
static Value constantIndexTypeEncoding(ConversionPatternRewriter &rewriter,
                                       Location loc,
                                       SparseTensorEncodingAttr &enc) {
  return constantOverheadTypeEncoding(rewriter, loc, enc.getIndexBitWidth());
}

/// Generates a constant of the internal type encoding for primary storage.
static Value constantPrimaryTypeEncoding(ConversionPatternRewriter &rewriter,
                                         Location loc, Type tp) {
  PrimaryType primary;
  if (tp.isF64())
    primary = PrimaryType::kF64;
  else if (tp.isF32())
    primary = PrimaryType::kF32;
  else if (tp.isInteger(64))
    primary = PrimaryType::kI64;
  else if (tp.isInteger(32))
    primary = PrimaryType::kI32;
  else if (tp.isInteger(16))
    primary = PrimaryType::kI16;
  else if (tp.isInteger(8))
    primary = PrimaryType::kI8;
  else
    llvm_unreachable("Unknown element type");
  return constantI32(rewriter, loc, static_cast<uint32_t>(primary));
}

/// Generates a constant of the internal dimension level type encoding.
static Value
constantDimLevelTypeEncoding(ConversionPatternRewriter &rewriter, Location loc,
                             SparseTensorEncodingAttr::DimLevelType dlt) {
  DimLevelType dlt2;
  switch (dlt) {
  case SparseTensorEncodingAttr::DimLevelType::Dense:
    dlt2 = DimLevelType::kDense;
    break;
  case SparseTensorEncodingAttr::DimLevelType::Compressed:
    dlt2 = DimLevelType::kCompressed;
    break;
  case SparseTensorEncodingAttr::DimLevelType::Singleton:
    dlt2 = DimLevelType::kSingleton;
    break;
  }
  return constantI8(rewriter, loc, static_cast<uint8_t>(dlt2));
}

/// Returns the equivalent of `void*` for opaque arguments to the
/// execution engine.
static Type getOpaquePointerType(PatternRewriter &rewriter) {
  return LLVM::LLVMPointerType::get(rewriter.getI8Type());
}

/// Returns a function reference (first hit also inserts into module). Sets
/// the "llvm.emit_c_interface" attribute on the function declaration when
/// requested, so that LLVM lowering generates a wrapper function that takes
/// care of ABI complications with passing in and returning MemRefs to C
/// functions.
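/// The declaration is created as private on first use, so subsequent
/// lookups of the same name reuse a single declaration per module.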
static FlatSymbolRefAttr getFunc(Operation *op, StringRef name,
                                 TypeRange resultType, ValueRange operands,
                                 EmitCInterface emitCInterface) {
  MLIRContext *context = op->getContext();
  auto module = op->getParentOfType<ModuleOp>();
  auto result = SymbolRefAttr::get(context, name);
  auto func = module.lookupSymbol<FuncOp>(result.getAttr());
  if (!func) {
    OpBuilder moduleBuilder(module.getBodyRegion());
    func = moduleBuilder.create<FuncOp>(
        op->getLoc(), name,
        FunctionType::get(context, operands.getTypes(), resultType));
    func.setPrivate();
    if (static_cast<bool>(emitCInterface))
      func->setAttr("llvm.emit_c_interface", UnitAttr::get(context));
  }
  return result;
}

/// Creates a `CallOp` to the function reference returned by `getFunc()`.
static CallOp createFuncCall(OpBuilder &builder, Operation *op, StringRef name,
                             TypeRange resultType, ValueRange operands,
                             EmitCInterface emitCInterface) {
  auto fn = getFunc(op, name, resultType, operands, emitCInterface);
  return builder.create<CallOp>(op->getLoc(), resultType, fn, operands);
}

/// Replaces the `op` with a `CallOp` to the function reference returned
/// by `getFunc()`.
static CallOp replaceOpWithFuncCall(PatternRewriter &rewriter, Operation *op,
                                    StringRef name, TypeRange resultType,
                                    ValueRange operands,
                                    EmitCInterface emitCInterface) {
  auto fn = getFunc(op, name, resultType, operands, emitCInterface);
  return rewriter.replaceOpWithNewOp<CallOp>(op, resultType, fn, operands);
}

/// Generates dimension size call.
static Value genDimSizeCall(ConversionPatternRewriter &rewriter, Operation *op,
                            SparseTensorEncodingAttr &enc, Value src,
                            int64_t idx) {
  // Permute the index according to an optional dimension ordering.
  if (AffineMap p = enc.getDimOrdering())
    idx = p.getPermutedPosition(idx);
  // Generate the call.
  StringRef name = "sparseDimSize";
  SmallVector<Value, 2> params{src, constantIndex(rewriter, op->getLoc(), idx)};
  Type iTp = rewriter.getIndexType();
  return createFuncCall(rewriter, op, name, iTp, params, EmitCInterface::Off)
      .getResult(0);
}

/// Generates a call into the "swiss army knife" method of the sparse runtime
/// support library for materializing sparse tensors into the computation.
static Value genNewCall(ConversionPatternRewriter &rewriter, Operation *op,
                        ArrayRef<Value> params) {
  StringRef name = "newSparseTensor";
  Type pTp = getOpaquePointerType(rewriter);
  return createFuncCall(rewriter, op, name, pTp, params, EmitCInterface::On)
      .getResult(0);
}

/// Populates given sizes array from type.
static void sizesFromType(ConversionPatternRewriter &rewriter,
                          SmallVector<Value, 4> &sizes, Location loc,
                          ShapedType stp) {
  auto shape = stp.getShape();
  for (unsigned i = 0, rank = stp.getRank(); i < rank; i++) {
    uint64_t s = shape[i] == ShapedType::kDynamicSize ? 0 : shape[i];
    sizes.push_back(constantIndex(rewriter, loc, s));
  }
}

/// Populates given sizes array from source.
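/// Unlike `sizesFromType`, this handles dynamically shaped sources as
/// well: `createOrFoldDimOp` emits a `tensor.dim` query for dynamic
/// dimensions and folds static dimensions into constants.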
static void sizesFromSrc(ConversionPatternRewriter &rewriter,
                         SmallVector<Value, 4> &sizes, Location loc,
                         Value src) {
  unsigned rank = src.getType().cast<ShapedType>().getRank();
  for (unsigned i = 0; i < rank; i++)
    sizes.push_back(linalg::createOrFoldDimOp(rewriter, loc, src, i));
}

/// Populates given sizes array from type (for static sizes) and from
/// a source already converted into an opaque pointer (for dynamic sizes).
static void sizesFromPtr(ConversionPatternRewriter &rewriter,
                         SmallVector<Value, 4> &sizes, Operation *op,
                         SparseTensorEncodingAttr &enc, ShapedType stp,
                         Value src) {
  Location loc = op->getLoc();
  auto shape = stp.getShape();
  for (unsigned i = 0, rank = stp.getRank(); i < rank; i++)
    if (shape[i] == ShapedType::kDynamicSize)
      sizes.push_back(genDimSizeCall(rewriter, op, enc, src, i));
    else
      sizes.push_back(constantIndex(rewriter, loc, shape[i]));
}

/// Generates an uninitialized temporary buffer of the given size and
/// type, but returns it as type `memref<? x $tp>` (rather than as type
/// `memref<$sz x $tp>`).
static Value genAlloca(ConversionPatternRewriter &rewriter, Location loc,
                       Value sz, Type tp) {
  auto memTp = MemRefType::get({ShapedType::kDynamicSize}, tp);
  return rewriter.create<memref::AllocaOp>(loc, memTp, ValueRange{sz});
}

/// Generates an uninitialized temporary buffer of the given size and
/// type, but returns it as type `memref<? x $tp>` (rather than as type
/// `memref<$sz x $tp>`).
static Value genAlloca(ConversionPatternRewriter &rewriter, Location loc,
                       unsigned sz, Type tp) {
  return genAlloca(rewriter, loc, constantIndex(rewriter, loc, sz), tp);
}

/// Generates an uninitialized temporary buffer with room for one value
/// of the given type, and returns the `memref<$tp>`.
static Value genAllocaScalar(ConversionPatternRewriter &rewriter, Location loc,
                             Type tp) {
  return rewriter.create<memref::AllocaOp>(loc, MemRefType::get({}, tp));
}

/// Generates a temporary buffer of the given type and given contents.
static Value genBuffer(ConversionPatternRewriter &rewriter, Location loc,
                       ArrayRef<Value> values) {
  unsigned sz = values.size();
  assert(sz >= 1);
  Value buffer = genAlloca(rewriter, loc, sz, values[0].getType());
  for (unsigned i = 0; i < sz; i++) {
    Value idx = constantIndex(rewriter, loc, i);
    rewriter.create<memref::StoreOp>(loc, values[i], buffer, idx);
  }
  return buffer;
}

/// Populates parameters required to call the "swiss army knife" method of the
/// sparse runtime support library for materializing sparse tensors into the
/// computation.
static void newParams(ConversionPatternRewriter &rewriter,
                      SmallVector<Value, 8> &params, Operation *op,
                      SparseTensorEncodingAttr &enc, Action action,
                      ValueRange szs, Value ptr = Value()) {
  Location loc = op->getLoc();
  ArrayRef<SparseTensorEncodingAttr::DimLevelType> dlt = enc.getDimLevelType();
  unsigned sz = dlt.size();
  // Sparsity annotations.
  SmallVector<Value, 4> attrs;
  for (unsigned i = 0; i < sz; i++)
    attrs.push_back(constantDimLevelTypeEncoding(rewriter, loc, dlt[i]));
  params.push_back(genBuffer(rewriter, loc, attrs));
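  // The code below appends, in order: the dimension sizes buffer, the
  // dimension ordering buffer, the three type encodings, the action, and
  // the payload pointer. Callers rely on this layout when patching
  // entries in place (see the convert rules below):
  //   params[0] dim level types, params[1] dim sizes,
  //   params[2] dim ordering,    params[3] pointer type,
  //   params[4] index type,      params[5] primary type,
  //   params[6] action,          params[7] payload pointer.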
  // Dimension sizes array of the enveloping tensor. Useful for either
  // verification of external data, or for construction of internal data.
  SmallVector<Value, 4> sizes;
  for (Value s : szs)
    sizes.push_back(s);
  params.push_back(genBuffer(rewriter, loc, sizes));
  // Dimension order permutation array. This is the "identity" permutation by
  // default, or otherwise the "reverse" permutation of a given ordering, so
  // that indices can be mapped quickly to the right position.
  SmallVector<Value, 4> rev(sz);
  if (AffineMap p = enc.getDimOrdering()) {
    for (unsigned i = 0; i < sz; i++)
      rev[p.getDimPosition(i)] = constantIndex(rewriter, loc, i);
  } else {
    for (unsigned i = 0; i < sz; i++)
      rev[i] = constantIndex(rewriter, loc, i);
  }
  params.push_back(genBuffer(rewriter, loc, rev));
  // Secondary and primary type encodings.
  Type elemTp = op->getResult(0).getType().cast<ShapedType>().getElementType();
  params.push_back(constantPointerTypeEncoding(rewriter, loc, enc));
  params.push_back(constantIndexTypeEncoding(rewriter, loc, enc));
  params.push_back(constantPrimaryTypeEncoding(rewriter, loc, elemTp));
  // User action.
  params.push_back(constantAction(rewriter, loc, action));
  // Payload pointer.
  if (!ptr)
    ptr = rewriter.create<LLVM::NullOp>(loc, getOpaquePointerType(rewriter));
  params.push_back(ptr);
}

/// Generates the comparison `v != 0` where `v` is of numeric type `t`.
/// For floating types, we use the "unordered" comparator (i.e., returns
/// true if `v` is NaN).
static Value genIsNonzero(ConversionPatternRewriter &rewriter, Location loc,
                          Value v) {
  Type t = v.getType();
  Value zero = constantZero(rewriter, loc, t);
  if (t.isa<FloatType>())
    return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE, v,
                                          zero);
  if (t.isIntOrIndex())
    return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, v,
                                          zero);
  llvm_unreachable("Unknown element type");
}

/// Generates the code to read the value from tensor[ivs], and conditionally
/// stores the indices ivs to the memory in ind. The insertion point after
/// this routine is inside the if-then branch, right after the assignment to
/// ind, so that the addEltX call generated afterwards ends up inside that
/// branch. The generated code looks like the following:
///   if (tensor[ivs] != 0) {
///     ind = ivs
static Value genIndexAndValueForDense(ConversionPatternRewriter &rewriter,
                                      Location loc, Value tensor, Value ind,
                                      ValueRange ivs) {
  Value val = rewriter.create<tensor::ExtractOp>(loc, tensor, ivs);
  Value cond = genIsNonzero(rewriter, loc, val);
  scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ false);
  rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
  unsigned i = 0;
  for (auto iv : ivs) {
    Value idx = constantIndex(rewriter, loc, i++);
    rewriter.create<memref::StoreOp>(loc, iv, ind, idx);
  }
  return val;
}

/// Generates a call that adds one element to a coordinate scheme.
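/// The `perm` operand is expected to be the dimension ordering buffer
/// built by newParams() (i.e. `params[2]`).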
/// In particular, this generates code like the following:
///   val = a[i1,..,ik];
///   if val != 0
///     t->add(val, [i1,..,ik], [p1,..,pk]);
static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
                          Type eltType, Value ptr, Value val, Value ind,
                          Value perm) {
  StringRef name;
  if (eltType.isF64())
    name = "addEltF64";
  else if (eltType.isF32())
    name = "addEltF32";
  else if (eltType.isInteger(64))
    name = "addEltI64";
  else if (eltType.isInteger(32))
    name = "addEltI32";
  else if (eltType.isInteger(16))
    name = "addEltI16";
  else if (eltType.isInteger(8))
    name = "addEltI8";
  else
    llvm_unreachable("Unknown element type");
  SmallVector<Value, 4> params{ptr, val, ind, perm};
  Type pTp = getOpaquePointerType(rewriter);
  createFuncCall(rewriter, op, name, pTp, params, EmitCInterface::On);
}

/// Generates a call to `iter->getNext()`. If there is a next element,
/// then it is copied into the out-parameters `ind` and `elemPtr`,
/// and the return value is true. If there isn't a next element, then
/// the memory for `iter` is freed and the return value is false.
static Value genGetNextCall(ConversionPatternRewriter &rewriter, Operation *op,
                            Value iter, Value ind, Value elemPtr) {
  Type elemTp = elemPtr.getType().cast<ShapedType>().getElementType();
  StringRef name;
  if (elemTp.isF64())
    name = "getNextF64";
  else if (elemTp.isF32())
    name = "getNextF32";
  else if (elemTp.isInteger(64))
    name = "getNextI64";
  else if (elemTp.isInteger(32))
    name = "getNextI32";
  else if (elemTp.isInteger(16))
    name = "getNextI16";
  else if (elemTp.isInteger(8))
    name = "getNextI8";
  else
    llvm_unreachable("Unknown element type");
  SmallVector<Value, 3> params{iter, ind, elemPtr};
  Type i1 = rewriter.getI1Type();
  return createFuncCall(rewriter, op, name, i1, params, EmitCInterface::On)
      .getResult(0);
}

/// If the tensor is a sparse constant, generates and returns the pair of
/// the constants for the indices and the values.
static Optional<std::pair<Value, Value>>
genSplitSparseConstant(ConversionPatternRewriter &rewriter, Location loc,
                       Value tensor) {
  if (auto constOp = tensor.getDefiningOp<arith::ConstantOp>()) {
    if (auto attr = constOp.getValue().dyn_cast<SparseElementsAttr>()) {
      DenseElementsAttr indicesAttr = attr.getIndices();
      Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr);
      DenseElementsAttr valuesAttr = attr.getValues();
      Value values = rewriter.create<arith::ConstantOp>(loc, valuesAttr);
      return std::make_pair(indices, values);
    }
  }
  return {};
}

/// Generates the code to copy the index at indices[ivs] to ind, and return
/// the value at value[ivs].
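/// Note that `ivs` is expected to hold a single induction variable (the
/// position within the constant's coordinate list): `indices` is the
/// rank-2 tensor of coordinates and `values` the matching rank-1 tensor
/// of nonzero values produced by genSplitSparseConstant().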
static Value genIndexAndValueForSparse(ConversionPatternRewriter &rewriter,
                                       Location loc, Value indices,
                                       Value values, Value ind, ValueRange ivs,
                                       unsigned rank) {
  for (unsigned i = 0; i < rank; i++) {
    Value idx = constantIndex(rewriter, loc, i);
    Value val = rewriter.create<tensor::ExtractOp>(loc, indices,
                                                   ValueRange{ivs[0], idx});
    val =
        rewriter.create<arith::IndexCastOp>(loc, val, rewriter.getIndexType());
    rewriter.create<memref::StoreOp>(loc, val, ind, idx);
  }
  return rewriter.create<tensor::ExtractOp>(loc, values, ivs[0]);
}

/// Generates code to allocate a tensor of the given type, and zero
/// initialize it. If the tensor type has any dynamic sizes, then the
/// `sizes` parameter should be as filled by sizesFromPtr(); that way
/// we can reuse the genDimSizeCall() results generated by sizesFromPtr().
static Value allocDenseTensor(ConversionPatternRewriter &rewriter,
                              Location loc, RankedTensorType tensorTp,
                              ValueRange sizes) {
  Type elemTp = tensorTp.getElementType();
  auto shape = tensorTp.getShape();
  auto memTp = MemRefType::get(shape, elemTp);
  SmallVector<Value> dynamicSizes;
  for (unsigned i = 0, rank = tensorTp.getRank(); i < rank; i++) {
    if (shape[i] == ShapedType::kDynamicSize)
      dynamicSizes.push_back(sizes[i]);
  }
  Value mem = rewriter.create<memref::AllocOp>(loc, memTp, dynamicSizes);
  Value zero = constantZero(rewriter, loc, elemTp);
  rewriter.create<linalg::FillOp>(loc, zero, mem);
  return mem;
}

/// Inserts the element returned by genGetNextCall(_, ind, elemPtr) into
/// the tensor created by allocDenseTensor(). The `rank` is the rank
/// of the `tensor` and the length of `ind`.
static void insertScalarIntoDenseTensor(ConversionPatternRewriter &rewriter,
                                        Location loc, Value elemPtr,
                                        Value tensor, unsigned rank,
                                        Value ind) {
  SmallVector<Value, 4> ivs;
  ivs.reserve(rank);
  for (unsigned i = 0; i < rank; i++) {
    Value idx = constantIndex(rewriter, loc, i);
    ivs.push_back(rewriter.create<memref::LoadOp>(loc, ind, idx));
  }
  Value elemV = rewriter.create<memref::LoadOp>(loc, elemPtr);
  rewriter.create<memref::StoreOp>(loc, elemV, tensor, ivs);
}

//===----------------------------------------------------------------------===//
// Conversion rules.
//===----------------------------------------------------------------------===//

/// Sparse conversion rule for returns.
class SparseReturnConverter : public OpConversionPattern<ReturnOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<ReturnOp>(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for dimension accesses.
class SparseTensorToDimSizeConverter
    : public OpConversionPattern<tensor::DimOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only rewrite annotated DimOp with constant index.
    auto enc = getSparseTensorEncoding(op.source().getType());
    if (!enc)
      return failure();
    Optional<int64_t> index = op.getConstantIndex();
    if (!index.hasValue())
      return failure();
    // Generate the call.
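    // (Any dimension ordering on the encoding is applied inside
    // genDimSizeCall(), so the runtime is queried in storage order.)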
    Value src = adaptor.getOperands()[0];
    int64_t idx = index.getValue();
    rewriter.replaceOp(op, genDimSizeCall(rewriter, op, enc, src, idx));
    return success();
  }
};

/// Sparse conversion rule for trivial tensor casts.
class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only rewrite identically annotated source/dest.
    auto encDst = getSparseTensorEncoding(op.getType());
    auto encSrc = getSparseTensorEncoding(op.source().getType());
    if (!encDst || encDst != encSrc)
      return failure();
    rewriter.replaceOp(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for the new operator.
class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NewOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    auto enc = getSparseTensorEncoding(resType);
    if (!enc)
      return failure();
    // Generate the call to construct tensor from ptr. The sizes are
    // inferred from the result type of the new operator.
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 8> params;
    sizesFromType(rewriter, sizes, op.getLoc(), resType.cast<ShapedType>());
    Value ptr = adaptor.getOperands()[0];
    newParams(rewriter, params, op, enc, Action::kFromFile, sizes, ptr);
    rewriter.replaceOp(op, genNewCall(rewriter, op, params));
    return success();
  }
};

/// Sparse conversion rule for the init operator.
class SparseTensorInitConverter : public OpConversionPattern<InitOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(InitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    auto enc = getSparseTensorEncoding(resType);
    if (!enc)
      return failure();
    // Generate the call to construct empty tensor. The sizes are
    // explicitly defined by the arguments to the init operator.
    SmallVector<Value, 8> params;
    newParams(rewriter, params, op, enc, Action::kEmpty, adaptor.getOperands());
    rewriter.replaceOp(op, genNewCall(rewriter, op, params));
    return success();
  }
};

/// Sparse conversion rule for the convert operator.
class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ConvertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    Type resType = op.getType();
    Type srcType = op.source().getType();
    auto encDst = getSparseTensorEncoding(resType);
    auto encSrc = getSparseTensorEncoding(srcType);
    Value src = adaptor.getOperands()[0];
    if (encDst && encSrc) {
      // This is a sparse => sparse conversion, which is handled as follows:
      //   t = src->toCOO();        ; src to COO in dst order
      //   dst = newSparseTensor(t)
      // Using the coordinate scheme as an intermediate does not always
      // yield the fastest conversion but avoids the need for a full
      // O(N^2) conversion matrix.
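      // (With N supported storage formats, direct format-to-format
      // conversions would require on the order of N^2 specialized
      // routines, whereas funneling through COO needs only one to-COO
      // and one from-COO path per format.)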
      if (encDst == encSrc) {
        rewriter.replaceOp(op, adaptor.getOperands()); // hidden nop cast
        return success();
      }
      SmallVector<Value, 4> sizes;
      SmallVector<Value, 8> params;
      sizesFromPtr(rewriter, sizes, op, encSrc, srcType.cast<ShapedType>(),
                   src);
      // Set up encoding with right mix of src and dst so that the two
      // method calls can share most parameters, while still providing
      // the correct sparsity information to either of them.
      auto enc = SparseTensorEncodingAttr::get(
          op->getContext(), encDst.getDimLevelType(), encDst.getDimOrdering(),
          encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
      newParams(rewriter, params, op, enc, Action::kToCOO, sizes, src);
      Value coo = genNewCall(rewriter, op, params);
      params[3] = constantPointerTypeEncoding(rewriter, loc, encDst);
      params[4] = constantIndexTypeEncoding(rewriter, loc, encDst);
      params[6] = constantAction(rewriter, loc, Action::kFromCOO);
      params[7] = coo;
      rewriter.replaceOp(op, genNewCall(rewriter, op, params));
      return success();
    }
    if (!encDst && encSrc) {
      // This is sparse => dense conversion, which is handled as follows:
      //   dst = new Tensor(0);
      //   iter = src->toCOO();
      //   iter->startIterator();
      //   while (elem = iter->getNext()) {
      //     dst[elem.indices] = elem.value;
      //   }
      RankedTensorType dstTensorTp = resType.cast<RankedTensorType>();
      RankedTensorType srcTensorTp = srcType.cast<RankedTensorType>();
      unsigned rank = dstTensorTp.getRank();
      Type elemTp = dstTensorTp.getElementType();
      // Fabricate a no-permutation encoding for newParams().
      // The pointer/index types must be those of `src`.
      // The dimLevelTypes aren't actually used by Action::kToIterator.
      encDst = SparseTensorEncodingAttr::get(
          op->getContext(),
          SmallVector<SparseTensorEncodingAttr::DimLevelType>(
              rank, SparseTensorEncodingAttr::DimLevelType::Dense),
          AffineMap(), encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
      SmallVector<Value, 4> sizes;
      SmallVector<Value, 8> params;
      sizesFromPtr(rewriter, sizes, op, encSrc, srcTensorTp, src);
      newParams(rewriter, params, op, encDst, Action::kToIterator, sizes, src);
      Value iter = genNewCall(rewriter, op, params);
      Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
      Value elemPtr = genAllocaScalar(rewriter, loc, elemTp);
      Value dst = allocDenseTensor(rewriter, loc, dstTensorTp, sizes);
      SmallVector<Value> noArgs;
      SmallVector<Type> noTypes;
      auto whileOp = rewriter.create<scf::WhileOp>(loc, noTypes, noArgs);
      Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, noTypes);
      rewriter.setInsertionPointToEnd(before);
      Value cond = genGetNextCall(rewriter, op, iter, ind, elemPtr);
      rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments());
      Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, noTypes);
      rewriter.setInsertionPointToStart(after);
      insertScalarIntoDenseTensor(rewriter, loc, elemPtr, dst, rank, ind);
      rewriter.create<scf::YieldOp>(loc);
      rewriter.setInsertionPointAfter(whileOp);
      rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, resType, dst);
      return success();
    }
    if (!encDst && !encSrc) {
      // dense => dense
      return failure();
    }
    // This is a dense => sparse conversion or a sparse constant in COO =>
    // sparse conversion, which is handled as follows:
    //   t = newSparseCOO()
    //   ...code to fill the COO tensor t...
    //   s = newSparseTensor(t)
    //
    // To fill the COO tensor from a dense tensor:
    //   for i1 in dim1
    //     ..
    //       for ik in dimk
    //         val = a[i1,..,ik]
    //         if val != 0
    //           t->add(val, [i1,..,ik], [p1,..,pk])
    //
    // To fill the COO tensor from a sparse constant in COO format:
    //   for i in range(NNZ)
    //     val = values[i]
    //     [i1,..,ik] = indices[i]
    //     t->add(val, [i1,..,ik], [p1,..,pk])
    //
    // Note that the dense tensor traversal code is actually implemented
    // using MLIR IR to avoid having to expose too much low-level
    // memref traversal details to the runtime support library.
    // Also note that the code below only generates the "new" ops and
    // the loop-nest per se; whereas the entire body of the innermost
    // loop is generated by genAddEltCall().
    ShapedType stp = resType.cast<ShapedType>();
    unsigned rank = stp.getRank();
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 8> params;
    sizesFromSrc(rewriter, sizes, loc, src);
    newParams(rewriter, params, op, encDst, Action::kEmptyCOO, sizes);
    Value ptr = genNewCall(rewriter, op, params);
    Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
    Value perm = params[2];
    SmallVector<Value> lo;
    SmallVector<Value> hi;
    SmallVector<Value> st;
    Value zero = constantIndex(rewriter, loc, 0);
    Value one = constantIndex(rewriter, loc, 1);
    auto indicesValues = genSplitSparseConstant(rewriter, loc, src);
    bool isCOOConstant = indicesValues.hasValue();
    Value indices;
    Value values;
    if (isCOOConstant) {
      indices = indicesValues->first;
      values = indicesValues->second;
      lo.push_back(zero);
      hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, values, 0));
      st.push_back(one);
    } else {
      for (unsigned i = 0; i < rank; i++) {
        lo.push_back(zero);
        hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, src, i));
        st.push_back(one);
      }
    }
    Type eltType = stp.getElementType();
    scf::buildLoopNest(
        rewriter, op.getLoc(), lo, hi, st, {},
        [&](OpBuilder &builder, Location loc, ValueRange ivs,
            ValueRange args) -> scf::ValueVector {
          Value val;
          if (isCOOConstant)
            val = genIndexAndValueForSparse(rewriter, loc, indices, values,
                                            ind, ivs, rank);
          else
            val = genIndexAndValueForDense(rewriter, loc, src, ind, ivs);
          genAddEltCall(rewriter, op, eltType, ptr, val, ind, perm);
          return {};
        });
    // Final call to construct sparse tensor storage.
    params[6] = constantAction(rewriter, loc, Action::kFromCOO);
    params[7] = ptr;
    rewriter.replaceOp(op, genNewCall(rewriter, op, params));
    return success();
  }
};

/// Sparse conversion rule for the release operator.
class SparseTensorReleaseConverter : public OpConversionPattern<ReleaseOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReleaseOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    StringRef name = "delSparseTensor";
    TypeRange noTp;
    createFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
                   EmitCInterface::Off);
    rewriter.eraseOp(op);
    return success();
  }
};

/// Sparse conversion rule for pointer accesses.
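/// The result is a memref view into the pointers array held by the
/// runtime, dispatched by element type to `sparsePointers` or one of
/// its width-specialized variants below.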
class SparseTensorToPointersConverter
    : public OpConversionPattern<ToPointersOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToPointersOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type eltType = resType.cast<ShapedType>().getElementType();
    StringRef name;
    if (eltType.isIndex())
      name = "sparsePointers";
    else if (eltType.isInteger(64))
      name = "sparsePointers64";
    else if (eltType.isInteger(32))
      name = "sparsePointers32";
    else if (eltType.isInteger(16))
      name = "sparsePointers16";
    else if (eltType.isInteger(8))
      name = "sparsePointers8";
    else
      return failure();
    replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
                          EmitCInterface::On);
    return success();
  }
};

/// Sparse conversion rule for index accesses.
class SparseTensorToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToIndicesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type eltType = resType.cast<ShapedType>().getElementType();
    StringRef name;
    if (eltType.isIndex())
      name = "sparseIndices";
    else if (eltType.isInteger(64))
      name = "sparseIndices64";
    else if (eltType.isInteger(32))
      name = "sparseIndices32";
    else if (eltType.isInteger(16))
      name = "sparseIndices16";
    else if (eltType.isInteger(8))
      name = "sparseIndices8";
    else
      return failure();
    replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
                          EmitCInterface::On);
    return success();
  }
};

/// Sparse conversion rule for value accesses.
class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type eltType = resType.cast<ShapedType>().getElementType();
    StringRef name;
    if (eltType.isF64())
      name = "sparseValuesF64";
    else if (eltType.isF32())
      name = "sparseValuesF32";
    else if (eltType.isInteger(64))
      name = "sparseValuesI64";
    else if (eltType.isInteger(32))
      name = "sparseValuesI32";
    else if (eltType.isInteger(16))
      name = "sparseValuesI16";
    else if (eltType.isInteger(8))
      name = "sparseValuesI8";
    else
      return failure();
    replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
                          EmitCInterface::On);
    return success();
  }
};

/// Sparse conversion rule for tensor rematerialization.
class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.hasInserts()) {
      // Finalize any pending insertions.
      StringRef name = "endInsert";
      TypeRange noTp;
      createFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
                     EmitCInterface::Off);
    }
    rewriter.replaceOp(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for inserting in lexicographic index order.
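/// Indices are expected to arrive in lexicographic order with respect
/// to the dimension ordering of the tensor, which the runtime
/// `lexInsert` entry points assume.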
class SparseTensorLexInsertConverter : public OpConversionPattern<LexInsertOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(LexInsertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = op.tensor().getType();
    Type eltType = srcType.cast<ShapedType>().getElementType();
    StringRef name;
    if (eltType.isF64())
      name = "lexInsertF64";
    else if (eltType.isF32())
      name = "lexInsertF32";
    else if (eltType.isInteger(64))
      name = "lexInsertI64";
    else if (eltType.isInteger(32))
      name = "lexInsertI32";
    else if (eltType.isInteger(16))
      name = "lexInsertI16";
    else if (eltType.isInteger(8))
      name = "lexInsertI8";
    else
      llvm_unreachable("Unknown element type");
    TypeRange noTp;
    replaceOpWithFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
                          EmitCInterface::On);
    return success();
  }
};

class SparseTensorExpandConverter : public OpConversionPattern<ExpandOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ExpandOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    ShapedType srcType = op.tensor().getType().cast<ShapedType>();
    Type eltType = srcType.getElementType();
    Type boolType = rewriter.getIntegerType(1);
    Type idxType = rewriter.getIndexType();
    // All initialization should be done on entry of the loop nest.
    rewriter.setInsertionPointAfter(op.tensor().getDefiningOp());
    // Determine the size for access expansion.
    auto enc = getSparseTensorEncoding(srcType);
    Value src = adaptor.getOperands()[0];
    Value sz = genDimSizeCall(rewriter, op, enc, src, srcType.getRank() - 1);
    // Allocate temporary stack buffers for values, filled-switch, and indices.
    Value values = genAlloca(rewriter, loc, sz, eltType);
    Value filled = genAlloca(rewriter, loc, sz, boolType);
    Value indices = genAlloca(rewriter, loc, sz, idxType);
    Value zero = constantZero(rewriter, loc, idxType);
    // Reset the values/filled-switch to all-zero/false. Note that this
    // introduces an O(N) operation into the computation, but this reset
    // operation is amortized over the innermost loops for the access
    // pattern expansion.
    rewriter.create<linalg::FillOp>(loc, constantZero(rewriter, loc, eltType),
                                    values);
    rewriter.create<linalg::FillOp>(loc, constantZero(rewriter, loc, boolType),
                                    filled);
    // Replace expansion op with these buffers and initial index.
    assert(op.getNumResults() == 4);
    rewriter.replaceOp(op, {values, filled, indices, zero});
    return success();
  }
};

class SparseTensorCompressConverter : public OpConversionPattern<CompressOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(CompressOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Note that this method call resets the values/filled-switch back to
    // all-zero/false by only iterating over the set elements, so the
    // complexity remains proportional to the sparsity of the expanded
    // access pattern.
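    // (In other words, the reset cost is proportional to the number of
    // elements actually set during expansion, not to the full size of
    // the temporary buffers allocated by the expand rule above.)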
    Type srcType = op.tensor().getType();
    Type eltType = srcType.cast<ShapedType>().getElementType();
    StringRef name;
    if (eltType.isF64())
      name = "expInsertF64";
    else if (eltType.isF32())
      name = "expInsertF32";
    else if (eltType.isInteger(64))
      name = "expInsertI64";
    else if (eltType.isInteger(32))
      name = "expInsertI32";
    else if (eltType.isInteger(16))
      name = "expInsertI16";
    else if (eltType.isInteger(8))
      name = "expInsertI8";
    else
      return failure();
    TypeRange noTp;
    replaceOpWithFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
                          EmitCInterface::On);
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with conversion rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
                                                  RewritePatternSet &patterns) {
  patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter,
               SparseCastConverter, SparseTensorNewConverter,
               SparseTensorInitConverter, SparseTensorConvertConverter,
               SparseTensorReleaseConverter, SparseTensorToPointersConverter,
               SparseTensorToIndicesConverter, SparseTensorToValuesConverter,
               SparseTensorLoadConverter, SparseTensorLexInsertConverter,
               SparseTensorExpandConverter, SparseTensorCompressConverter>(
      typeConverter, patterns.getContext());
}