//===- SparseTensorLowering.cpp - Sparse tensor primitives conversion ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Convert sparse tensor primitives to calls into a runtime support library.
// Note that this is a current implementation choice to keep the conversion
// simple. In principle, these primitives could also be lowered to actual
// IR that implements the primitives directly on the selected sparse tensor
// storage schemes.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

/// Returns the internal type encoding for primary storage. Keep these
/// values consistent with the sparse runtime support library.
static unsigned getPrimaryTypeEncoding(Type tp) {
  if (tp.isF64())
    return 1;
  if (tp.isF32())
    return 2;
  if (tp.isInteger(64))
    return 3;
  if (tp.isInteger(32))
    return 4;
  if (tp.isInteger(16))
    return 5;
  if (tp.isInteger(8))
    return 6;
  return 0;
}

/// Returns the internal type encoding for overhead storage. Keep these
/// values consistent with the sparse runtime support library.
static unsigned getOverheadTypeEncoding(unsigned width) {
  switch (width) {
  default:
    return 1;
  case 32:
    return 2;
  case 16:
    return 3;
  case 8:
    return 4;
  }
}

/// Returns the internal dimension level type encoding. Keep these
/// values consistent with the sparse runtime support library.
static unsigned
getDimLevelTypeEncoding(SparseTensorEncodingAttr::DimLevelType dlt) {
  switch (dlt) {
  case SparseTensorEncodingAttr::DimLevelType::Dense:
    return 0;
  case SparseTensorEncodingAttr::DimLevelType::Compressed:
    return 1;
  case SparseTensorEncodingAttr::DimLevelType::Singleton:
    return 2;
  }
  llvm_unreachable("Unknown SparseTensorEncodingAttr::DimLevelType");
}
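
// For illustration only: the runtime support library is assumed to mirror
// the three encodings above with constants along the following lines (a
// hypothetical sketch, not the authoritative definitions, which live in the
// runtime sources and must be kept in sync by hand):
//
//   enum class PrimaryType : uint64_t { kF64 = 1, kF32 = 2, kI64 = 3,
//                                       kI32 = 4, kI16 = 5, kI8 = 6 };
//   enum class OverheadType : uint64_t { kU64 = 1, kU32 = 2, kU16 = 3,
//                                        kU8 = 4 };
//   enum class DimLevelType : uint8_t { kDense = 0, kCompressed = 1,
//                                       kSingleton = 2 };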
/// Returns integers of given width and values as a constant tensor.
/// We cast the static shape into a dynamic shape to ensure that the
/// method signature remains uniform across different tensor dimensions.
static Value getTensor(ConversionPatternRewriter &rewriter, unsigned width,
                       Location loc, ArrayRef<APInt> values) {
  Type etp = rewriter.getIntegerType(width);
  unsigned sz = values.size();
  RankedTensorType tt1 = RankedTensorType::get({sz}, etp);
  RankedTensorType tt2 = RankedTensorType::get({ShapedType::kDynamicSize}, etp);
  auto elts =
      rewriter.create<ConstantOp>(loc, DenseElementsAttr::get(tt1, values));
  return rewriter.create<tensor::CastOp>(loc, tt2, elts);
}

/// Returns a function reference (first hit also inserts into module). Sets
/// the "llvm.emit_c_interface" attribute on the function declaration when
/// requested, so that LLVM lowering generates a wrapper function that takes
/// care of ABI complications with passing in and returning MemRefs to C
/// functions.
static FlatSymbolRefAttr getFunc(Operation *op, StringRef name, Type resultType,
                                 ValueRange operands,
                                 bool emitCInterface = false) {
  MLIRContext *context = op->getContext();
  auto module = op->getParentOfType<ModuleOp>();
  auto result = SymbolRefAttr::get(context, name);
  auto func = module.lookupSymbol<FuncOp>(result.getAttr());
  if (!func) {
    OpBuilder moduleBuilder(module.getBodyRegion());
    func = moduleBuilder.create<FuncOp>(
        op->getLoc(), name,
        FunctionType::get(context, operands.getTypes(), resultType));
    func.setPrivate();
    if (emitCInterface)
      func->setAttr("llvm.emit_c_interface", UnitAttr::get(context));
  }
  return result;
}
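
// For reference, the call emitted by genNewCall() below has the following
// shape (a sketch of the generated IR rather than verbatim output; names
// abbreviated):
//
//   %tensor = call @newSparseTensor(%attrs, %sizes, %perm,
//                                   %secPtr, %secInd, %primary,
//                                   %action, %ptr)
//     : (tensor<?xi8>, tensor<?xi64>, tensor<?xi64>,
//        i64, i64, i64, i32, !llvm.ptr<i8>) -> !llvm.ptr<i8>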
/// Generates a call into the "swiss army knife" method of the sparse runtime
/// support library for materializing sparse tensors into the computation. The
/// method returns the call value and assigns the permutation to 'perm'. The
/// 'action' values must be kept consistent with the runtime support library;
/// as used in this file, 0 reads a tensor from file, 1 constructs a tensor
/// from a coordinate scheme, 2 creates an empty coordinate scheme, and 3
/// converts an existing tensor into a coordinate scheme.
static Value genNewCall(ConversionPatternRewriter &rewriter, Operation *op,
                        SparseTensorEncodingAttr &enc, uint32_t action,
                        Value &perm, Value ptr = Value()) {
  Location loc = op->getLoc();
  ShapedType resType = op->getResult(0).getType().cast<ShapedType>();
  SmallVector<Value, 8> params;
  // Sparsity annotations in tensor constant form.
  SmallVector<APInt, 4> attrs;
  unsigned sz = enc.getDimLevelType().size();
  for (unsigned i = 0; i < sz; i++)
    attrs.push_back(
        APInt(8, getDimLevelTypeEncoding(enc.getDimLevelType()[i])));
  params.push_back(getTensor(rewriter, 8, loc, attrs));
  // Dimension sizes array of the enveloping *dense* tensor. Useful for either
  // verification of external data, or for construction of internal data.
  auto shape = resType.getShape();
  SmallVector<APInt, 4> sizes;
  for (unsigned i = 0; i < sz; i++) {
    uint64_t s = shape[i] == ShapedType::kDynamicSize ? 0 : shape[i];
    sizes.push_back(APInt(64, s));
  }
  params.push_back(getTensor(rewriter, 64, loc, sizes));
  // Dimension order permutation array. This is the "identity" permutation by
  // default, or otherwise the "reverse" permutation of a given ordering, so
  // that indices can be mapped quickly to the right position.
  SmallVector<APInt, 4> rev(sz);
  if (AffineMap p = enc.getDimOrdering()) {
    for (unsigned i = 0; i < sz; i++)
      rev[p.getDimPosition(i)] = APInt(64, i);
  } else {
    for (unsigned i = 0; i < sz; i++)
      rev[i] = APInt(64, i);
  }
  perm = getTensor(rewriter, 64, loc, rev);
  params.push_back(perm);
  // Secondary and primary types encoding.
  unsigned secPtr = getOverheadTypeEncoding(enc.getPointerBitWidth());
  unsigned secInd = getOverheadTypeEncoding(enc.getIndexBitWidth());
  unsigned primary = getPrimaryTypeEncoding(resType.getElementType());
  assert(primary);
  params.push_back(
      rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(secPtr)));
  params.push_back(
      rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(secInd)));
  params.push_back(
      rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(primary)));
  // User action and pointer.
  Type pTp = LLVM::LLVMPointerType::get(IntegerType::get(op->getContext(), 8));
  if (!ptr)
    ptr = rewriter.create<LLVM::NullOp>(loc, pTp);
  params.push_back(
      rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(action)));
  params.push_back(ptr);
  // Generate the call to create a new tensor.
  StringRef name = "newSparseTensor";
  auto call = rewriter.create<CallOp>(
      loc, pTp, getFunc(op, name, pTp, params, /*emitCInterface=*/true),
      params);
  return call.getResult(0);
}

/// Generates a call that adds one element to a coordinate scheme.
static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
                          Value ptr, Value tensor, Value ind, Value perm,
                          ValueRange ivs) {
  Location loc = op->getLoc();
  StringRef name;
  Type eltType = tensor.getType().cast<ShapedType>().getElementType();
  if (eltType.isF64())
    name = "addEltF64";
  else if (eltType.isF32())
    name = "addEltF32";
  else if (eltType.isInteger(64))
    name = "addEltI64";
  else if (eltType.isInteger(32))
    name = "addEltI32";
  else if (eltType.isInteger(16))
    name = "addEltI16";
  else if (eltType.isInteger(8))
    name = "addEltI8";
  else
    llvm_unreachable("Unknown element type");
  Value val = rewriter.create<tensor::ExtractOp>(loc, tensor, ivs);
  // TODO: add if here?
  unsigned i = 0;
  for (auto iv : ivs) {
    Value idx = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(i++));
    rewriter.create<memref::StoreOp>(loc, iv, ind, idx);
  }
  SmallVector<Value, 8> params;
  params.push_back(ptr);
  params.push_back(val);
  params.push_back(ind);
  params.push_back(perm);
  Type pTp = LLVM::LLVMPointerType::get(IntegerType::get(op->getContext(), 8));
  rewriter.create<CallOp>(
      loc, pTp, getFunc(op, name, pTp, params, /*emitCInterface=*/true),
      params);
}

//===----------------------------------------------------------------------===//
// Conversion rules.
//===----------------------------------------------------------------------===//

/// Sparse conversion rule for returns.
class SparseReturnConverter : public OpConversionPattern<ReturnOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReturnOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<ReturnOp>(op, operands);
    return success();
  }
};
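
// As an example of the rules that follow, a dimension query such as
//
//   %d = tensor.dim %arg, %c1 : tensor<?x?xf64, #CSR>
//
// is rewritten into a library call on the opaque tensor pointer (a sketch
// of the generated IR, with the dimension index already permuted; #CSR is
// a placeholder for an arbitrary sparse tensor encoding):
//
//   %d = call @sparseDimSize(%ptr, %c1) : (!llvm.ptr<i8>, index) -> index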
/// Sparse conversion rule for dimension accesses.
class SparseTensorToDimSizeConverter
    : public OpConversionPattern<tensor::DimOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::DimOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    auto enc = getSparseTensorEncoding(op.source().getType());
    if (!enc)
      return failure();
    // Permute the dim index.
    Optional<int64_t> index = op.getConstantIndex();
    if (!index.hasValue())
      return failure();
    int64_t idx = index.getValue();
    if (AffineMap p = enc.getDimOrdering())
      idx = p.getPermutedPosition(idx);
    // Generate the call.
    StringRef name = "sparseDimSize";
    SmallVector<Value, 2> params;
    params.push_back(operands[0]);
    params.push_back(
        rewriter.create<ConstantOp>(op.getLoc(), rewriter.getIndexAttr(idx)));
    rewriter.replaceOpWithNewOp<CallOp>(
        op, resType, getFunc(op, name, resType, params), params);
    return success();
  }
};

/// Sparse conversion rule for the new operator.
class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NewOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    auto enc = getSparseTensorEncoding(resType);
    if (!enc)
      return failure();
    Value perm;
    rewriter.replaceOp(op, genNewCall(rewriter, op, enc, 0, perm, operands[0]));
    return success();
  }
};
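
// The convert rule below distinguishes the conversion cases by the presence
// of sparse encodings on the source and destination types. For instance, a
// sparse => sparse conversion such as (a sketch; #CSC and #CSR stand for
// arbitrary source and destination encodings)
//
//   %1 = sparse_tensor.convert %0
//     : tensor<8x8xf64, #CSC> to tensor<8x8xf64, #CSR>
//
// becomes two newSparseTensor calls: one that repacks %0 into a coordinate
// scheme in the destination order (action 3), and one that materializes the
// destination tensor from that scheme (action 1).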
/// Sparse conversion rule for the convert operator.
class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ConvertOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    auto encDst = getSparseTensorEncoding(resType);
    auto encSrc = getSparseTensorEncoding(op.source().getType());
    if (encDst && encSrc) {
      // This is a sparse => sparse conversion, which is handled as follows:
      //   t = src->asCOO();         ; src to COO in dst order
      //   dst = newSparseTensor(t)
      // Using the coordinate scheme as an intermediate does not always
      // yield the fastest conversion but avoids the need for a full
      // O(N^2) conversion matrix.
      Value perm;
      Value coo = genNewCall(rewriter, op, encDst, 3, perm, operands[0]);
      rewriter.replaceOp(op, genNewCall(rewriter, op, encDst, 1, perm, coo));
      return success();
    }
    if (!encDst || encSrc) {
      // TODO: sparse => dense
      return failure();
    }
    // This is a dense => sparse conversion, which is handled as follows:
    //   t = newSparseCOO()
    //   for i1 in dim1
    //    ..
    //     for ik in dimk
    //       val = a[i1,..,ik]
    //       if val != 0
    //         t->add(val, [i1,..,ik], [p1,..,pk])
    //   s = newSparseTensor(t)
    // Note that the dense tensor traversal code is actually implemented
    // using MLIR IR to avoid having to expose too many low-level
    // memref traversal details to the runtime support library.
    Location loc = op->getLoc();
    ShapedType shape = resType.cast<ShapedType>();
    auto memTp =
        MemRefType::get({ShapedType::kDynamicSize}, rewriter.getIndexType());
    Value perm;
    Value ptr = genNewCall(rewriter, op, encDst, 2, perm);
    Value tensor = operands[0];
    Value arg = rewriter.create<ConstantOp>(
        loc, rewriter.getIndexAttr(shape.getRank()));
    Value ind = rewriter.create<memref::AllocaOp>(loc, memTp, ValueRange{arg});
    SmallVector<Value> lo;
    SmallVector<Value> hi;
    SmallVector<Value> st;
    Value zero = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(0));
    Value one = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(1));
    for (unsigned i = 0, rank = shape.getRank(); i < rank; i++) {
      lo.push_back(zero);
      hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, tensor, i));
      st.push_back(one);
    }
    scf::buildLoopNest(rewriter, loc, lo, hi, st, {},
                       [&](OpBuilder &builder, Location loc, ValueRange ivs,
                           ValueRange args) -> scf::ValueVector {
                         genAddEltCall(rewriter, op, ptr, tensor, ind, perm,
                                       ivs);
                         return {};
                       });
    rewriter.replaceOp(op, genNewCall(rewriter, op, encDst, 1, perm, ptr));
    return success();
  }
};

/// Sparse conversion rule for pointer accesses.
class SparseTensorToPointersConverter
    : public OpConversionPattern<ToPointersOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToPointersOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type eltType = resType.cast<ShapedType>().getElementType();
    StringRef name;
    if (eltType.isIndex())
      name = "sparsePointers";
    else if (eltType.isInteger(64))
      name = "sparsePointers64";
    else if (eltType.isInteger(32))
      name = "sparsePointers32";
    else if (eltType.isInteger(16))
      name = "sparsePointers16";
    else if (eltType.isInteger(8))
      name = "sparsePointers8";
    else
      return failure();
    rewriter.replaceOpWithNewOp<CallOp>(
        op, resType,
        getFunc(op, name, resType, operands, /*emitCInterface=*/true),
        operands);
    return success();
  }
};

/// Sparse conversion rule for index accesses.
class SparseTensorToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToIndicesOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type eltType = resType.cast<ShapedType>().getElementType();
    StringRef name;
    if (eltType.isIndex())
      name = "sparseIndices";
    else if (eltType.isInteger(64))
      name = "sparseIndices64";
    else if (eltType.isInteger(32))
      name = "sparseIndices32";
    else if (eltType.isInteger(16))
      name = "sparseIndices16";
    else if (eltType.isInteger(8))
      name = "sparseIndices8";
    else
      return failure();
    rewriter.replaceOpWithNewOp<CallOp>(
        op, resType,
        getFunc(op, name, resType, operands, /*emitCInterface=*/true),
        operands);
    return success();
  }
};
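
// For example, a values access such as (a sketch; #CSR is again a
// placeholder encoding)
//
//   %v = sparse_tensor.values %t : tensor<?x?xf64, #CSR> to memref<?xf64>
//
// lowers to a call on the opaque tensor pointer, with the memref result
// returned through the C interface wrapper:
//
//   %v = call @sparseValuesF64(%ptr) : (!llvm.ptr<i8>) -> memref<?xf64>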
/// Sparse conversion rule for value accesses.
class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToValuesOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type eltType = resType.cast<ShapedType>().getElementType();
    StringRef name;
    if (eltType.isF64())
      name = "sparseValuesF64";
    else if (eltType.isF32())
      name = "sparseValuesF32";
    else if (eltType.isInteger(64))
      name = "sparseValuesI64";
    else if (eltType.isInteger(32))
      name = "sparseValuesI32";
    else if (eltType.isInteger(16))
      name = "sparseValuesI16";
    else if (eltType.isInteger(8))
      name = "sparseValuesI8";
    else
      return failure();
    rewriter.replaceOpWithNewOp<CallOp>(
        op, resType,
        getFunc(op, name, resType, operands, /*emitCInterface=*/true),
        operands);
    return success();
  }
};

/// Sparse conversion rule for tensor reconstruction.
class SparseTensorToTensorConverter : public OpConversionPattern<ToTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  // Simply fold the operator into the pointer to the sparse storage scheme.
  LogicalResult
  matchAndRewrite(ToTensorOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // Check that all arguments of the tensor reconstruction operator are
    // calls into the support library that query exactly the same opaque
    // pointer.
    Value ptr;
    for (Value opnd : operands) {
      if (auto call = opnd.getDefiningOp<CallOp>()) {
        Value arg = call.getOperand(0);
        if (!arg.getType().isa<LLVM::LLVMPointerType>())
          return failure();
        if (!ptr)
          ptr = arg;
        else if (arg != ptr)
          return failure();
      }
    }
    // If a single opaque pointer is found, perform the folding.
    if (!ptr)
      return failure();
    rewriter.replaceOp(op, ptr);
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with conversion rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
                                                  RewritePatternSet &patterns) {
  patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter,
               SparseTensorNewConverter, SparseTensorConvertConverter,
               SparseTensorToPointersConverter, SparseTensorToIndicesConverter,
               SparseTensorToValuesConverter, SparseTensorToTensorConverter>(
      typeConverter, patterns.getContext());
}
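
// A minimal usage sketch (hypothetical pass body; the actual driver lives in
// the sparse tensor conversion pass): the patterns are combined with a type
// converter and a conversion target before running a partial conversion.
//
//   RewritePatternSet patterns(ctx);
//   SparseTensorTypeConverter converter;   // hypothetical converter name
//   ConversionTarget target(*ctx);
//   populateSparseTensorConversionPatterns(converter, patterns);
//   if (failed(applyPartialConversion(getOperation(), target,
//                                     std::move(patterns))))
//     signalPassFailure();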