//===- SparseTensorConversion.cpp - Sparse tensor primitives conversion ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Convert sparse tensor primitives to calls into a runtime support library.
// Note that this is a current implementation choice to keep the conversion
// simple. In principle, these primitives could also be converted to actual
// elaborate IR code that implements the primitives on the selected sparse
// tensor storage schemes.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/ExecutionEngine/SparseTensorUtils.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

/// Shorthand aliases for the `emitCInterface` argument to `getFunc()`,
/// `createFuncCall()`, and `replaceOpWithFuncCall()`.
enum class EmitCInterface : bool { Off = false, On = true };

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

/// Generates a constant zero of the given type.
inline static Value constantZero(ConversionPatternRewriter &rewriter,
                                 Location loc, Type t) {
  return rewriter.create<arith::ConstantOp>(loc, t, rewriter.getZeroAttr(t));
}

/// Generates a constant of `index` type.
inline static Value constantIndex(ConversionPatternRewriter &rewriter,
                                  Location loc, int64_t i) {
  return rewriter.create<arith::ConstantIndexOp>(loc, i);
}

/// Generates a constant of `i32` type.
inline static Value constantI32(ConversionPatternRewriter &rewriter,
                                Location loc, int32_t i) {
  return rewriter.create<arith::ConstantIntOp>(loc, i, 32);
}

/// Generates a constant of `i8` type.
inline static Value constantI8(ConversionPatternRewriter &rewriter,
                               Location loc, int8_t i) {
  return rewriter.create<arith::ConstantIntOp>(loc, i, 8);
}

/// Generates a constant of the given `Action`.
static Value constantAction(ConversionPatternRewriter &rewriter, Location loc,
                            Action action) {
  return constantI32(rewriter, loc, static_cast<uint32_t>(action));
}

/// Generates a constant of the internal type encoding for overhead storage.
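/// For example (illustrative only), a bit-width of 32 in the encoding maps
/// to the runtime constant `OverheadType::kU32`, while the default width of
/// 0 selects the 64-bit encoding `OverheadType::kU64`; the chosen encoding
/// is then materialized as an `i32` constant in the IR.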
static Value constantOverheadTypeEncoding(ConversionPatternRewriter &rewriter,
                                          Location loc, unsigned width) {
  OverheadType sec;
  switch (width) {
  default:
    sec = OverheadType::kU64;
    break;
  case 32:
    sec = OverheadType::kU32;
    break;
  case 16:
    sec = OverheadType::kU16;
    break;
  case 8:
    sec = OverheadType::kU8;
    break;
  }
  return constantI32(rewriter, loc, static_cast<uint32_t>(sec));
}

/// Generates a constant of the internal type encoding for pointer
/// overhead storage.
static Value constantPointerTypeEncoding(ConversionPatternRewriter &rewriter,
                                         Location loc,
                                         SparseTensorEncodingAttr &enc) {
  return constantOverheadTypeEncoding(rewriter, loc, enc.getPointerBitWidth());
}

/// Generates a constant of the internal type encoding for index overhead
/// storage.
static Value constantIndexTypeEncoding(ConversionPatternRewriter &rewriter,
                                       Location loc,
                                       SparseTensorEncodingAttr &enc) {
  return constantOverheadTypeEncoding(rewriter, loc, enc.getIndexBitWidth());
}

/// Generates a constant of the internal type encoding for primary storage.
static Value constantPrimaryTypeEncoding(ConversionPatternRewriter &rewriter,
                                         Location loc, Type tp) {
  PrimaryType primary;
  if (tp.isF64())
    primary = PrimaryType::kF64;
  else if (tp.isF32())
    primary = PrimaryType::kF32;
  else if (tp.isInteger(64))
    primary = PrimaryType::kI64;
  else if (tp.isInteger(32))
    primary = PrimaryType::kI32;
  else if (tp.isInteger(16))
    primary = PrimaryType::kI16;
  else if (tp.isInteger(8))
    primary = PrimaryType::kI8;
  else
    llvm_unreachable("Unknown element type");
  return constantI32(rewriter, loc, static_cast<uint32_t>(primary));
}

/// Generates a constant of the internal dimension level type encoding.
static Value
constantDimLevelTypeEncoding(ConversionPatternRewriter &rewriter, Location loc,
                             SparseTensorEncodingAttr::DimLevelType dlt) {
  DimLevelType dlt2;
  switch (dlt) {
  case SparseTensorEncodingAttr::DimLevelType::Dense:
    dlt2 = DimLevelType::kDense;
    break;
  case SparseTensorEncodingAttr::DimLevelType::Compressed:
    dlt2 = DimLevelType::kCompressed;
    break;
  case SparseTensorEncodingAttr::DimLevelType::Singleton:
    dlt2 = DimLevelType::kSingleton;
    break;
  }
  return constantI8(rewriter, loc, static_cast<uint8_t>(dlt2));
}

/// Returns the equivalent of `void*` for opaque arguments to the
/// execution engine.
static Type getOpaquePointerType(PatternRewriter &rewriter) {
  return LLVM::LLVMPointerType::get(rewriter.getI8Type());
}

/// Returns a function reference (first hit also inserts into module). Sets
/// the "llvm.emit_c_interface" attribute on the function declaration when
/// requested, so that LLVM lowering generates a wrapper function that takes
/// care of ABI complications with passing in and returning MemRefs to C
/// functions.
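///
/// For example (illustrative), the first call requesting "newSparseTensor"
/// with the C interface enabled inserts a private declaration along the
/// lines of:
///
///   func private @newSparseTensor(...) -> !llvm.ptr<i8>
///       attributes { llvm.emit_c_interface }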
static FlatSymbolRefAttr getFunc(Operation *op, StringRef name,
                                 TypeRange resultType, ValueRange operands,
                                 EmitCInterface emitCInterface) {
  MLIRContext *context = op->getContext();
  auto module = op->getParentOfType<ModuleOp>();
  auto result = SymbolRefAttr::get(context, name);
  auto func = module.lookupSymbol<FuncOp>(result.getAttr());
  if (!func) {
    OpBuilder moduleBuilder(module.getBodyRegion());
    func = moduleBuilder.create<FuncOp>(
        op->getLoc(), name,
        FunctionType::get(context, operands.getTypes(), resultType));
    func.setPrivate();
    if (static_cast<bool>(emitCInterface))
      func->setAttr("llvm.emit_c_interface", UnitAttr::get(context));
  }
  return result;
}

/// Creates a `CallOp` to the function reference returned by `getFunc()`.
static CallOp createFuncCall(OpBuilder &builder, Operation *op, StringRef name,
                             TypeRange resultType, ValueRange operands,
                             EmitCInterface emitCInterface) {
  auto fn = getFunc(op, name, resultType, operands, emitCInterface);
  return builder.create<CallOp>(op->getLoc(), resultType, fn, operands);
}

/// Replaces the `op` with a `CallOp` to the function reference returned
/// by `getFunc()`.
static CallOp replaceOpWithFuncCall(PatternRewriter &rewriter, Operation *op,
                                    StringRef name, TypeRange resultType,
                                    ValueRange operands,
                                    EmitCInterface emitCInterface) {
  auto fn = getFunc(op, name, resultType, operands, emitCInterface);
  return rewriter.replaceOpWithNewOp<CallOp>(op, resultType, fn, operands);
}

/// Generates dimension size call.
static Value genDimSizeCall(ConversionPatternRewriter &rewriter, Operation *op,
                            SparseTensorEncodingAttr &enc, Value src,
                            int64_t idx) {
  // Permute the index according to an optional dimension ordering.
  if (AffineMap p = enc.getDimOrdering())
    idx = p.getPermutedPosition(idx);
  // Generate the call.
  StringRef name = "sparseDimSize";
  SmallVector<Value, 2> params{src, constantIndex(rewriter, op->getLoc(), idx)};
  Type iTp = rewriter.getIndexType();
  return createFuncCall(rewriter, op, name, iTp, params, EmitCInterface::Off)
      .getResult(0);
}

/// Generates a call into the "swiss army knife" method of the sparse runtime
/// support library for materializing sparse tensors into the computation.
static Value genNewCall(ConversionPatternRewriter &rewriter, Operation *op,
                        ArrayRef<Value> params) {
  StringRef name = "newSparseTensor";
  Type pTp = getOpaquePointerType(rewriter);
  return createFuncCall(rewriter, op, name, pTp, params, EmitCInterface::On)
      .getResult(0);
}

/// Populates given sizes array from type.
static void sizesFromType(ConversionPatternRewriter &rewriter,
                          SmallVector<Value, 4> &sizes, Location loc,
                          ShapedType stp) {
  auto shape = stp.getShape();
  for (unsigned i = 0, rank = stp.getRank(); i < rank; i++) {
    uint64_t s = shape[i] == ShapedType::kDynamicSize ? 0 : shape[i];
    sizes.push_back(constantIndex(rewriter, loc, s));
  }
}

/// Populates given sizes array from source.
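///
/// For a `tensor<4x?xf64>` source, for instance, this pushes a constant
/// `%c4 = arith.constant 4 : index` for the static dimension and a
/// `tensor.dim %src, %c1` for the dynamic one (constant-folded when
/// possible).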
static void sizesFromSrc(ConversionPatternRewriter &rewriter,
                         SmallVector<Value, 4> &sizes, Location loc,
                         Value src) {
  unsigned rank = src.getType().cast<ShapedType>().getRank();
  for (unsigned i = 0; i < rank; i++)
    sizes.push_back(linalg::createOrFoldDimOp(rewriter, loc, src, i));
}

/// Populates given sizes array from type (for static sizes) and from
/// a source that has already been converted into an opaque pointer
/// (for dynamic sizes).
static void sizesFromPtr(ConversionPatternRewriter &rewriter,
                         SmallVector<Value, 4> &sizes, Operation *op,
                         SparseTensorEncodingAttr &enc, ShapedType stp,
                         Value src) {
  Location loc = op->getLoc();
  auto shape = stp.getShape();
  for (unsigned i = 0, rank = stp.getRank(); i < rank; i++)
    if (shape[i] == ShapedType::kDynamicSize)
      sizes.push_back(genDimSizeCall(rewriter, op, enc, src, i));
    else
      sizes.push_back(constantIndex(rewriter, loc, shape[i]));
}

/// Generates an uninitialized temporary buffer of the given size and
/// type, but returns it as type `memref<? x $tp>` (rather than as type
/// `memref<$sz x $tp>`).
static Value genAlloca(ConversionPatternRewriter &rewriter, Location loc,
                       unsigned sz, Type tp) {
  auto memTp = MemRefType::get({ShapedType::kDynamicSize}, tp);
  Value a = constantIndex(rewriter, loc, sz);
  return rewriter.create<memref::AllocaOp>(loc, memTp, ValueRange{a});
}

/// Generates an uninitialized temporary buffer with room for one value
/// of the given type, and returns the `memref<$tp>`.
static Value genAllocaScalar(ConversionPatternRewriter &rewriter, Location loc,
                             Type tp) {
  return rewriter.create<memref::AllocaOp>(loc, MemRefType::get({}, tp));
}

/// Generates a temporary buffer of the given type and given contents.
static Value genBuffer(ConversionPatternRewriter &rewriter, Location loc,
                       ArrayRef<Value> values) {
  unsigned sz = values.size();
  assert(sz >= 1);
  Value buffer = genAlloca(rewriter, loc, sz, values[0].getType());
  for (unsigned i = 0; i < sz; i++) {
    Value idx = constantIndex(rewriter, loc, i);
    rewriter.create<memref::StoreOp>(loc, values[i], buffer, idx);
  }
  return buffer;
}

/// Populates parameters required to call the "swiss army knife" method of the
/// sparse runtime support library for materializing sparse tensors into the
/// computation.
static void newParams(ConversionPatternRewriter &rewriter,
                      SmallVector<Value, 8> &params, Operation *op,
                      SparseTensorEncodingAttr &enc, Action action,
                      ValueRange szs, Value ptr = Value()) {
  Location loc = op->getLoc();
  ArrayRef<SparseTensorEncodingAttr::DimLevelType> dlt = enc.getDimLevelType();
  unsigned sz = dlt.size();
  // Sparsity annotations.
  SmallVector<Value, 4> attrs;
  for (unsigned i = 0; i < sz; i++)
    attrs.push_back(constantDimLevelTypeEncoding(rewriter, loc, dlt[i]));
  params.push_back(genBuffer(rewriter, loc, attrs));
  // Dimension sizes array of the enveloping tensor. Useful for either
  // verification of external data, or for construction of internal data.
  SmallVector<Value, 4> sizes;
  for (Value s : szs)
    sizes.push_back(s);
  params.push_back(genBuffer(rewriter, loc, sizes));
  // Dimension order permutation array. This is the "identity" permutation by
  // default, or otherwise the "reverse" permutation of a given ordering, so
  // that indices can be mapped quickly to the right position.
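  // For example (illustrative), for the ordering (d0, d1, d2) -> (d2, d0, d1)
  // the buffer built below holds the inverse permutation [1, 2, 0].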
  SmallVector<Value, 4> rev(sz);
  if (AffineMap p = enc.getDimOrdering()) {
    for (unsigned i = 0; i < sz; i++)
      rev[p.getDimPosition(i)] = constantIndex(rewriter, loc, i);
  } else {
    for (unsigned i = 0; i < sz; i++)
      rev[i] = constantIndex(rewriter, loc, i);
  }
  params.push_back(genBuffer(rewriter, loc, rev));
  // Secondary and primary types encoding.
  Type elemTp = op->getResult(0).getType().cast<ShapedType>().getElementType();
  params.push_back(constantPointerTypeEncoding(rewriter, loc, enc));
  params.push_back(constantIndexTypeEncoding(rewriter, loc, enc));
  params.push_back(constantPrimaryTypeEncoding(rewriter, loc, elemTp));
  // User action.
  params.push_back(constantAction(rewriter, loc, action));
  // Payload pointer.
  if (!ptr)
    ptr = rewriter.create<LLVM::NullOp>(loc, getOpaquePointerType(rewriter));
  params.push_back(ptr);
}

/// Generates the comparison `v != 0` where `v` is of numeric type `t`.
/// For floating types, we use the "unordered" comparator (i.e., returns
/// true if `v` is NaN).
static Value genIsNonzero(ConversionPatternRewriter &rewriter, Location loc,
                          Value v) {
  Type t = v.getType();
  Value zero = constantZero(rewriter, loc, t);
  if (t.isa<FloatType>())
    return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE, v,
                                          zero);
  if (t.isIntOrIndex())
    return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, v,
                                          zero);
  llvm_unreachable("Unknown element type");
}

/// Generates the code to read the value from tensor[ivs], and conditionally
/// stores the indices ivs to the memory in ind. The generated code looks like
/// the following and the insertion point after this routine is inside the
/// if-then branch behind the assignment to ind. This is to ensure that the
/// addEltX call generated after is inside the if-then branch.
///    if (tensor[ivs] != 0) {
///      ind = ivs
static Value genIndexAndValueForDense(ConversionPatternRewriter &rewriter,
                                      Location loc, Value tensor, Value ind,
                                      ValueRange ivs) {
  Value val = rewriter.create<tensor::ExtractOp>(loc, tensor, ivs);
  Value cond = genIsNonzero(rewriter, loc, val);
  scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ false);
  rewriter.setInsertionPointToStart(&ifOp.thenRegion().front());
  unsigned i = 0;
  for (auto iv : ivs) {
    Value idx = constantIndex(rewriter, loc, i++);
    rewriter.create<memref::StoreOp>(loc, iv, ind, idx);
  }
  return val;
}

/// Generates a call that adds one element to a coordinate scheme.
/// In particular, this generates code like the following:
///   val = a[i1,..,ik];
///   if val != 0
///     t->add(val, [i1,..,ik], [p1,..,pk]);
static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
                          Type eltType, Value ptr, Value val, Value ind,
                          Value perm) {
  StringRef name;
  if (eltType.isF64())
    name = "addEltF64";
  else if (eltType.isF32())
    name = "addEltF32";
  else if (eltType.isInteger(64))
    name = "addEltI64";
  else if (eltType.isInteger(32))
    name = "addEltI32";
  else if (eltType.isInteger(16))
    name = "addEltI16";
  else if (eltType.isInteger(8))
    name = "addEltI8";
  else
    llvm_unreachable("Unknown element type");
  SmallVector<Value, 4> params{ptr, val, ind, perm};
  Type pTp = getOpaquePointerType(rewriter);
  createFuncCall(rewriter, op, name, pTp, params, EmitCInterface::On);
}

/// Generates a call to `iter->getNext()`. If there is a next element,
/// then it is copied into the out-parameters `ind` and `elemPtr`,
/// and the return value is true. If there isn't a next element, then
/// the memory for `iter` is freed and the return value is false.
static Value genGetNextCall(ConversionPatternRewriter &rewriter, Operation *op,
                            Value iter, Value ind, Value elemPtr) {
  Type elemTp = elemPtr.getType().cast<ShapedType>().getElementType();
  StringRef name;
  if (elemTp.isF64())
    name = "getNextF64";
  else if (elemTp.isF32())
    name = "getNextF32";
  else if (elemTp.isInteger(64))
    name = "getNextI64";
  else if (elemTp.isInteger(32))
    name = "getNextI32";
  else if (elemTp.isInteger(16))
    name = "getNextI16";
  else if (elemTp.isInteger(8))
    name = "getNextI8";
  else
    llvm_unreachable("Unknown element type");
  SmallVector<Value, 3> params{iter, ind, elemPtr};
  Type i1 = rewriter.getI1Type();
  return createFuncCall(rewriter, op, name, i1, params, EmitCInterface::On)
      .getResult(0);
}

/// If the tensor is a sparse constant, generates and returns the pair of
/// the constants for the indices and the values.
static Optional<std::pair<Value, Value>>
genSplitSparseConstant(ConversionPatternRewriter &rewriter, Location loc,
                       Value tensor) {
  if (auto constOp = tensor.getDefiningOp<arith::ConstantOp>()) {
    if (auto attr = constOp.getValue().dyn_cast<SparseElementsAttr>()) {
      DenseElementsAttr indicesAttr = attr.getIndices();
      Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr);
      DenseElementsAttr valuesAttr = attr.getValues();
      Value values = rewriter.create<arith::ConstantOp>(loc, valuesAttr);
      return std::make_pair(indices, values);
    }
  }
  return {};
}

/// Generates the code to copy the index at indices[ivs] to ind, and return
/// the value at value[ivs].
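///
/// A rough sketch of the IR generated per dimension `i` (illustrative,
/// assuming i64 indices in the constant):
///   %idx = arith.constant <i> : index
///   %v   = tensor.extract %indices[%iv0, %idx] : tensor<?x?xi64>
///   %c   = arith.index_cast %v : i64 to index
///   memref.store %c, %ind[%idx] : memref<?xindex>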
static Value genIndexAndValueForSparse(ConversionPatternRewriter &rewriter,
                                       Location loc, Value indices,
                                       Value values, Value ind, ValueRange ivs,
                                       unsigned rank) {
  for (unsigned i = 0; i < rank; i++) {
    Value idx = constantIndex(rewriter, loc, i);
    Value val = rewriter.create<tensor::ExtractOp>(loc, indices,
                                                   ValueRange{ivs[0], idx});
    val =
        rewriter.create<arith::IndexCastOp>(loc, val, rewriter.getIndexType());
    rewriter.create<memref::StoreOp>(loc, val, ind, idx);
  }
  return rewriter.create<tensor::ExtractOp>(loc, values, ivs[0]);
}

/// Generates code to allocate a tensor of the given type, and zero
/// initialize it. If the tensor type has any dynamic sizes, then the
/// `sizes` parameter should be as filled by sizesFromPtr(); that way
/// we can reuse the genDimSizeCall() results generated by sizesFromPtr().
static Value allocDenseTensor(ConversionPatternRewriter &rewriter,
                              Location loc, RankedTensorType tensorTp,
                              ValueRange sizes) {
  Type elemTp = tensorTp.getElementType();
  auto shape = tensorTp.getShape();
  auto memTp = MemRefType::get(shape, elemTp);
  SmallVector<Value> dynamicSizes;
  for (unsigned i = 0, rank = tensorTp.getRank(); i < rank; i++) {
    if (shape[i] == ShapedType::kDynamicSize)
      dynamicSizes.push_back(sizes[i]);
  }
  Value mem = rewriter.create<memref::AllocOp>(loc, memTp, dynamicSizes);
  Value zero = constantZero(rewriter, loc, elemTp);
  rewriter.create<linalg::FillOp>(loc, zero, mem);
  return mem;
}

/// Inserts the element returned by genGetNextCall(_, ind, elemPtr) into
/// the tensor created by allocDenseTensor(). The `rank` is the rank
/// of the `tensor` and the length of `ind`.
static void insertScalarIntoDenseTensor(ConversionPatternRewriter &rewriter,
                                        Location loc, Value elemPtr,
                                        Value tensor, unsigned rank,
                                        Value ind) {
  SmallVector<Value, 4> ivs;
  ivs.reserve(rank);
  for (unsigned i = 0; i < rank; i++) {
    Value idx = constantIndex(rewriter, loc, i);
    ivs.push_back(rewriter.create<memref::LoadOp>(loc, ind, idx));
  }
  Value elemV = rewriter.create<memref::LoadOp>(loc, elemPtr);
  rewriter.create<memref::StoreOp>(loc, elemV, tensor, ivs);
}

//===----------------------------------------------------------------------===//
// Conversion rules.
//===----------------------------------------------------------------------===//

/// Sparse conversion rule for returns.
class SparseReturnConverter : public OpConversionPattern<ReturnOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<ReturnOp>(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for dimension accesses.
class SparseTensorToDimSizeConverter
    : public OpConversionPattern<tensor::DimOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only rewrite annotated DimOp with constant index.
    auto enc = getSparseTensorEncoding(op.source().getType());
    if (!enc)
      return failure();
    Optional<int64_t> index = op.getConstantIndex();
    if (!index.hasValue())
      return failure();
    // Generate the call.
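    // E.g. (illustrative), `tensor.dim %t, %c0` on an annotated tensor
    // lowers to roughly:
    //   %0 = call @sparseDimSize(%ptr, %c0) : (!llvm.ptr<i8>, index) -> index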
    Value src = adaptor.getOperands()[0];
    int64_t idx = index.getValue();
    rewriter.replaceOp(op, genDimSizeCall(rewriter, op, enc, src, idx));
    return success();
  }
};

/// Sparse conversion rule for trivial tensor casts.
class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only rewrite identically annotated source/dest.
    auto encDst = getSparseTensorEncoding(op.getType());
    auto encSrc = getSparseTensorEncoding(op.source().getType());
    if (!encDst || encDst != encSrc)
      return failure();
    rewriter.replaceOp(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for the new operator.
class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NewOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    auto enc = getSparseTensorEncoding(resType);
    if (!enc)
      return failure();
    // Generate the call to construct tensor from ptr. The sizes are
    // inferred from the result type of the new operator.
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 8> params;
    sizesFromType(rewriter, sizes, op.getLoc(), resType.cast<ShapedType>());
    Value ptr = adaptor.getOperands()[0];
    newParams(rewriter, params, op, enc, Action::kFromFile, sizes, ptr);
    rewriter.replaceOp(op, genNewCall(rewriter, op, params));
    return success();
  }
};

/// Sparse conversion rule for the init operator.
class SparseTensorInitConverter : public OpConversionPattern<InitOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(InitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    auto enc = getSparseTensorEncoding(resType);
    if (!enc)
      return failure();
    // Generate the call to construct empty tensor. The sizes are
    // explicitly defined by the arguments to the init operator.
    SmallVector<Value, 8> params;
    newParams(rewriter, params, op, enc, Action::kEmpty, adaptor.getOperands());
    rewriter.replaceOp(op, genNewCall(rewriter, op, params));
    return success();
  }
};

/// Sparse conversion rule for the convert operator.
class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ConvertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    Type resType = op.getType();
    Type srcType = op.source().getType();
    auto encDst = getSparseTensorEncoding(resType);
    auto encSrc = getSparseTensorEncoding(srcType);
    Value src = adaptor.getOperands()[0];
    if (encDst && encSrc) {
      // This is a sparse => sparse conversion, which is handled as follows:
      //   t = src->toCOO();         ; src to COO in dst order
      //   dst = newSparseTensor(t)
      // Using the coordinate scheme as an intermediate does not always
      // yield the fastest conversion but avoids the need for a full
      // O(N^2) conversion matrix.
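      // In terms of the emitted runtime calls (illustrative):
      //   %coo = call @newSparseTensor(..., kToCOO, %src)
      //   %dst = call @newSparseTensor(..., kFromCOO, %coo)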
      if (encDst == encSrc) {
        rewriter.replaceOp(op, adaptor.getOperands()); // hidden nop cast
        return success();
      }
      SmallVector<Value, 4> sizes;
      SmallVector<Value, 8> params;
      sizesFromPtr(rewriter, sizes, op, encSrc, srcType.cast<ShapedType>(),
                   src);
      // Set up encoding with right mix of src and dst so that the two
      // method calls can share most parameters, while still providing
      // the correct sparsity information to either of them.
      auto enc = SparseTensorEncodingAttr::get(
          op->getContext(), encDst.getDimLevelType(), encDst.getDimOrdering(),
          encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
      newParams(rewriter, params, op, enc, Action::kToCOO, sizes, src);
      Value coo = genNewCall(rewriter, op, params);
      params[3] = constantPointerTypeEncoding(rewriter, loc, encDst);
      params[4] = constantIndexTypeEncoding(rewriter, loc, encDst);
      params[6] = constantAction(rewriter, loc, Action::kFromCOO);
      params[7] = coo;
      rewriter.replaceOp(op, genNewCall(rewriter, op, params));
      return success();
    }
    if (!encDst && encSrc) {
      // This is sparse => dense conversion, which is handled as follows:
      //   dst = new Tensor(0);
      //   iter = src->toCOO();
      //   iter->startIterator();
      //   while (elem = iter->getNext()) {
      //     dst[elem.indices] = elem.value;
      //   }
      RankedTensorType dstTensorTp = resType.cast<RankedTensorType>();
      RankedTensorType srcTensorTp = srcType.cast<RankedTensorType>();
      unsigned rank = dstTensorTp.getRank();
      Type elemTp = dstTensorTp.getElementType();
      // Fabricate a no-permutation encoding for newParams().
      // The pointer/index types must be those of `src`.
      // The dimLevelTypes aren't actually used by Action::kToIterator.
      encDst = SparseTensorEncodingAttr::get(
          op->getContext(),
          SmallVector<SparseTensorEncodingAttr::DimLevelType>(
              rank, SparseTensorEncodingAttr::DimLevelType::Dense),
          AffineMap(), encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
      SmallVector<Value, 4> sizes;
      SmallVector<Value, 8> params;
      sizesFromPtr(rewriter, sizes, op, encSrc, srcTensorTp, src);
      newParams(rewriter, params, op, encDst, Action::kToIterator, sizes, src);
      Value iter = genNewCall(rewriter, op, params);
      Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
      Value elemPtr = genAllocaScalar(rewriter, loc, elemTp);
      Value dst = allocDenseTensor(rewriter, loc, dstTensorTp, sizes);
      SmallVector<Value> noArgs;
      SmallVector<Type> noTypes;
      auto whileOp = rewriter.create<scf::WhileOp>(loc, noTypes, noArgs);
      Block *before = rewriter.createBlock(&whileOp.before(), {}, noTypes);
      rewriter.setInsertionPointToEnd(before);
      Value cond = genGetNextCall(rewriter, op, iter, ind, elemPtr);
      rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments());
      Block *after = rewriter.createBlock(&whileOp.after(), {}, noTypes);
      rewriter.setInsertionPointToStart(after);
      insertScalarIntoDenseTensor(rewriter, loc, elemPtr, dst, rank, ind);
      rewriter.create<scf::YieldOp>(loc);
      rewriter.setInsertionPointAfter(whileOp);
      rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, resType, dst);
      return success();
    }
    if (!encDst && !encSrc) {
      // dense => dense
      return failure();
    }
    // This is a dense => sparse conversion or a sparse constant in COO =>
    // sparse conversion, which is handled as follows:
    //   t = newSparseCOO()
    //   ...code to fill the COO tensor t...
    //   s = newSparseTensor(t)
    //
    // To fill the COO tensor from a dense tensor:
    //   for i1 in dim1
    //    ..
    //     for ik in dimk
    //       val = a[i1,..,ik]
    //       if val != 0
    //         t->add(val, [i1,..,ik], [p1,..,pk])
    //
    // To fill the COO tensor from a sparse constant in COO format:
    //   for i in range(NNZ)
    //     val = values[i]
    //     [i1,..,ik] = indices[i]
    //     t->add(val, [i1,..,ik], [p1,..,pk])
    //
    // Note that the dense tensor traversal code is actually implemented
    // using MLIR IR to avoid having to expose too much low-level
    // memref traversal detail to the runtime support library.
    // Also note that the code below only generates the "new" ops and
    // the loop-nest per se; the entire body of the innermost loop is
    // generated by genAddEltCall().
    ShapedType stp = resType.cast<ShapedType>();
    unsigned rank = stp.getRank();
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 8> params;
    sizesFromSrc(rewriter, sizes, loc, src);
    newParams(rewriter, params, op, encDst, Action::kEmptyCOO, sizes);
    Value ptr = genNewCall(rewriter, op, params);
    Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
    Value perm = params[2];
    SmallVector<Value> lo;
    SmallVector<Value> hi;
    SmallVector<Value> st;
    Value zero = constantIndex(rewriter, loc, 0);
    Value one = constantIndex(rewriter, loc, 1);
    auto indicesValues = genSplitSparseConstant(rewriter, loc, src);
    bool isCOOConstant = indicesValues.hasValue();
    Value indices;
    Value values;
    if (isCOOConstant) {
      indices = indicesValues->first;
      values = indicesValues->second;
      lo.push_back(zero);
      hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, values, 0));
      st.push_back(one);
    } else {
      for (unsigned i = 0; i < rank; i++) {
        lo.push_back(zero);
        hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, src, i));
        st.push_back(one);
      }
    }
    Type eltType = stp.getElementType();
    scf::buildLoopNest(
        rewriter, op.getLoc(), lo, hi, st, {},
        [&](OpBuilder &builder, Location loc, ValueRange ivs,
            ValueRange args) -> scf::ValueVector {
          Value val;
          if (isCOOConstant)
            val = genIndexAndValueForSparse(rewriter, loc, indices, values,
                                            ind, ivs, rank);
          else
            val = genIndexAndValueForDense(rewriter, loc, src, ind, ivs);
          genAddEltCall(rewriter, op, eltType, ptr, val, ind, perm);
          return {};
        });
    // Final call to construct sparse tensor storage.
    params[6] = constantAction(rewriter, loc, Action::kFromCOO);
    params[7] = ptr;
    rewriter.replaceOp(op, genNewCall(rewriter, op, params));
    return success();
  }
};

/// Sparse conversion rule for the release operator.
class SparseTensorReleaseConverter : public OpConversionPattern<ReleaseOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReleaseOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    StringRef name = "delSparseTensor";
    TypeRange noTp;
    createFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
                   EmitCInterface::Off);
    rewriter.eraseOp(op);
    return success();
  }
};

/// Sparse conversion rule for pointer accesses.
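///
/// Sketch of the rewrite (illustrative, for an index-typed result):
///   %0 = sparse_tensor.pointers %t, %c0 : ... to memref<?xindex>
/// becomes
///   %0 = call @sparsePointers(%ptr, %c0)
///          : (!llvm.ptr<i8>, index) -> memref<?xindex>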
class SparseTensorToPointersConverter
    : public OpConversionPattern<ToPointersOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToPointersOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type eltType = resType.cast<ShapedType>().getElementType();
    StringRef name;
    if (eltType.isIndex())
      name = "sparsePointers";
    else if (eltType.isInteger(64))
      name = "sparsePointers64";
    else if (eltType.isInteger(32))
      name = "sparsePointers32";
    else if (eltType.isInteger(16))
      name = "sparsePointers16";
    else if (eltType.isInteger(8))
      name = "sparsePointers8";
    else
      return failure();
    replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
                          EmitCInterface::On);
    return success();
  }
};

/// Sparse conversion rule for index accesses.
class SparseTensorToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToIndicesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type eltType = resType.cast<ShapedType>().getElementType();
    StringRef name;
    if (eltType.isIndex())
      name = "sparseIndices";
    else if (eltType.isInteger(64))
      name = "sparseIndices64";
    else if (eltType.isInteger(32))
      name = "sparseIndices32";
    else if (eltType.isInteger(16))
      name = "sparseIndices16";
    else if (eltType.isInteger(8))
      name = "sparseIndices8";
    else
      return failure();
    replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
                          EmitCInterface::On);
    return success();
  }
};

/// Sparse conversion rule for value accesses.
class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type eltType = resType.cast<ShapedType>().getElementType();
    StringRef name;
    if (eltType.isF64())
      name = "sparseValuesF64";
    else if (eltType.isF32())
      name = "sparseValuesF32";
    else if (eltType.isInteger(64))
      name = "sparseValuesI64";
    else if (eltType.isInteger(32))
      name = "sparseValuesI32";
    else if (eltType.isInteger(16))
      name = "sparseValuesI16";
    else if (eltType.isInteger(8))
      name = "sparseValuesI8";
    else
      return failure();
    replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
                          EmitCInterface::On);
    return success();
  }
};

/// Sparse conversion rule for tensor rematerialization.
class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.hasInserts()) {
      // Finalize any pending insertions.
      StringRef name = "endInsert";
      TypeRange noTp;
      createFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
                     EmitCInterface::Off);
    }
    rewriter.replaceOp(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for inserting in lexicographic index order.
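///
/// Sketch (illustrative, for an f64 element type):
///   call @lexInsertF64(%ptr, %cursor, %val)
/// where %cursor is the memref holding the lexicographic coordinates.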
class SparseTensorLexInsertConverter : public OpConversionPattern<LexInsertOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(LexInsertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = op.tensor().getType();
    Type eltType = srcType.cast<ShapedType>().getElementType();
    StringRef name;
    if (eltType.isF64())
      name = "lexInsertF64";
    else if (eltType.isF32())
      name = "lexInsertF32";
    else if (eltType.isInteger(64))
      name = "lexInsertI64";
    else if (eltType.isInteger(32))
      name = "lexInsertI32";
    else if (eltType.isInteger(16))
      name = "lexInsertI16";
    else if (eltType.isInteger(8))
      name = "lexInsertI8";
    else
      llvm_unreachable("Unknown element type");
    TypeRange noTp;
    replaceOpWithFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
                          EmitCInterface::On);
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with conversion rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
                                                  RewritePatternSet &patterns) {
  patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter,
               SparseCastConverter, SparseTensorNewConverter,
               SparseTensorInitConverter, SparseTensorConvertConverter,
               SparseTensorReleaseConverter, SparseTensorToPointersConverter,
               SparseTensorToIndicesConverter, SparseTensorToValuesConverter,
               SparseTensorLoadConverter, SparseTensorLexInsertConverter>(
      typeConverter, patterns.getContext());
}
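
// A minimal usage sketch (illustrative; the type converter and target setup
// below are assumptions made for the example, not part of this file):
//
//   RewritePatternSet patterns(ctx);
//   SparseTensorTypeConverter converter; // maps annotated tensors to void*
//   ConversionTarget target(*ctx);
//   populateSparseTensorConversionPatterns(converter, patterns);
//   if (failed(applyPartialConversion(getOperation(), target,
//                                     std::move(patterns))))
//     signalPassFailure();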