//===- SparseTensorLowering.cpp - Sparse tensor primitives conversion -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Convert sparse tensor primitives to calls into a runtime support library.
// Note that this is a current implementation choice to keep the conversion
// simple. In principle, these primitives could also be converted to actual
// elaborate IR code that implements the primitives on the selected sparse
// tensor storage schemes.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

/// Returns internal type encoding for primary storage. Keep these
/// values consistent with the sparse runtime support library.
static unsigned getPrimaryTypeEncoding(Type tp) {
  if (tp.isF64())
    return 1;
  if (tp.isF32())
    return 2;
  if (tp.isInteger(64))
    return 3;
  if (tp.isInteger(32))
    return 4;
  if (tp.isInteger(16))
    return 5;
  if (tp.isInteger(8))
    return 6;
  return 0;
}

/// Returns internal type encoding for overhead storage. Keep these
/// values consistent with the sparse runtime support library.
static unsigned getOverheadTypeEncoding(unsigned width) {
  switch (width) {
  default:
    return 1;
  case 32:
    return 2;
  case 16:
    return 3;
  case 8:
    return 4;
  }
}

/// Returns internal dimension level type encoding. Keep these
/// values consistent with the sparse runtime support library.
static unsigned
getDimLevelTypeEncoding(SparseTensorEncodingAttr::DimLevelType dlt) {
  switch (dlt) {
  case SparseTensorEncodingAttr::DimLevelType::Dense:
    return 0;
  case SparseTensorEncodingAttr::DimLevelType::Compressed:
    return 1;
  case SparseTensorEncodingAttr::DimLevelType::Singleton:
    return 2;
  }
  llvm_unreachable("Unknown SparseTensorEncodingAttr::DimLevelType");
}
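
// For illustration, a CSR-style annotation such as
//
//   #CSR = #sparse_tensor.encoding<{
//     dimLevelType = [ "dense", "compressed" ],
//     pointerBitWidth = 32,
//     indexBitWidth = 32
//   }>
//
// (a sketch; the name #CSR is arbitrary) maps through the helpers above to
// dimension level encodings [0, 1], overhead encoding 2 for both the pointer
// and index bit widths, and, for an f64 element type, primary encoding 1.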

/// Returns integers of given width and values as a constant tensor.
/// We cast the static shape into a dynamic shape to ensure that the
/// method signature remains uniform across different tensor dimensions.
static Value getTensor(ConversionPatternRewriter &rewriter, unsigned width,
                       Location loc, ArrayRef<APInt> values) {
  Type etp = rewriter.getIntegerType(width);
  unsigned sz = values.size();
  RankedTensorType tt1 = RankedTensorType::get({sz}, etp);
  RankedTensorType tt2 =
      RankedTensorType::get({ShapedType::kDynamicSize}, etp);
  auto elts =
      rewriter.create<ConstantOp>(loc, DenseElementsAttr::get(tt1, values));
  return rewriter.create<tensor::CastOp>(loc, tt2, elts);
}

/// Returns function reference (first hit also inserts into module).
static FlatSymbolRefAttr getFunc(Operation *op, StringRef name,
                                 Type resultType, ValueRange operands) {
  MLIRContext *context = op->getContext();
  auto module = op->getParentOfType<ModuleOp>();
  auto result = SymbolRefAttr::get(context, name);
  auto func = module.lookupSymbol<FuncOp>(result.getAttr());
  if (!func) {
    OpBuilder moduleBuilder(module.getBodyRegion());
    moduleBuilder
        .create<FuncOp>(
            op->getLoc(), name,
            FunctionType::get(context, operands.getTypes(), resultType))
        .setPrivate();
  }
  return result;
}

/// Generates a call into the "swiss army knife" method of the sparse runtime
/// support library for materializing sparse tensors into the computation. The
/// method returns the call value and assigns the permutation to 'perm'.
static Value genNewCall(ConversionPatternRewriter &rewriter, Operation *op,
                        SparseTensorEncodingAttr &enc, uint32_t action,
                        Value &perm, Value ptr = Value()) {
  Location loc = op->getLoc();
  ShapedType resType = op->getResult(0).getType().cast<ShapedType>();
  SmallVector<Value, 8> params;
  // Sparsity annotations in tensor constant form.
  SmallVector<APInt, 4> attrs;
  unsigned sz = enc.getDimLevelType().size();
  for (unsigned i = 0; i < sz; i++)
    attrs.push_back(
        APInt(8, getDimLevelTypeEncoding(enc.getDimLevelType()[i])));
  params.push_back(getTensor(rewriter, 8, loc, attrs));
  // Dimension sizes array of the enveloping *dense* tensor. Useful for either
  // verification of external data, or for construction of internal data.
  auto shape = resType.getShape();
  SmallVector<APInt, 4> sizes;
  for (unsigned i = 0; i < sz; i++) {
    uint64_t s = shape[i] == ShapedType::kDynamicSize ? 0 : shape[i];
    sizes.push_back(APInt(64, s));
  }
  params.push_back(getTensor(rewriter, 64, loc, sizes));
  // Dimension order permutation array. This is the "identity" permutation by
  // default, or otherwise the "reverse" permutation of a given ordering, so
  // that indices can be mapped quickly to the right position.
  SmallVector<APInt, 4> rev(sz);
  if (AffineMap p = enc.getDimOrdering()) {
    for (unsigned i = 0; i < sz; i++)
      rev[p.getDimPosition(i)] = APInt(64, i);
  } else {
    for (unsigned i = 0; i < sz; i++)
      rev[i] = APInt(64, i);
  }
  perm = getTensor(rewriter, 64, loc, rev);
  params.push_back(perm);
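  // For example (a sketch), with dimOrdering affine_map<(i,j,k) -> (k,i,j)>
  // the loop above computes the reverse permutation [1, 2, 0]: storage
  // position 0 holds original dimension 2, hence rev[2] = 0, and so on.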
  // Secondary and primary types encoding.
  unsigned secPtr = getOverheadTypeEncoding(enc.getPointerBitWidth());
  unsigned secInd = getOverheadTypeEncoding(enc.getIndexBitWidth());
  unsigned primary = getPrimaryTypeEncoding(resType.getElementType());
  assert(primary);
  params.push_back(
      rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(secPtr)));
  params.push_back(
      rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(secInd)));
  params.push_back(
      rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(primary)));
  // User action and pointer.
  Type pTp = LLVM::LLVMPointerType::get(IntegerType::get(op->getContext(), 8));
  if (!ptr)
    ptr = rewriter.create<LLVM::NullOp>(loc, pTp);
  params.push_back(
      rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(action)));
  params.push_back(ptr);
  // Generate the call to create new tensor.
  StringRef name = "newSparseTensor";
  auto call =
      rewriter.create<CallOp>(loc, pTp, getFunc(op, name, pTp, params), params);
  return call.getResult(0);
}

/// Generates a call that adds one element to a coordinate scheme.
static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
                          Value ptr, Value tensor, Value ind, Value perm,
                          ValueRange ivs) {
  Location loc = op->getLoc();
  StringRef name;
  Type eltType = tensor.getType().cast<ShapedType>().getElementType();
  if (eltType.isF64())
    name = "addEltF64";
  else if (eltType.isF32())
    name = "addEltF32";
  else if (eltType.isInteger(64))
    name = "addEltI64";
  else if (eltType.isInteger(32))
    name = "addEltI32";
  else if (eltType.isInteger(16))
    name = "addEltI16";
  else if (eltType.isInteger(8))
    name = "addEltI8";
  else
    llvm_unreachable("Unknown element type");
  Value val = rewriter.create<tensor::ExtractOp>(loc, tensor, ivs);
  // TODO: add if here?
  unsigned i = 0;
  for (auto iv : ivs) {
    Value idx = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(i++));
    rewriter.create<memref::StoreOp>(loc, iv, ind, idx);
  }
  SmallVector<Value, 8> params;
  params.push_back(ptr);
  params.push_back(val);
  params.push_back(ind);
  params.push_back(perm);
  Type pTp = LLVM::LLVMPointerType::get(IntegerType::get(op->getContext(), 8));
  rewriter.create<CallOp>(loc, pTp, getFunc(op, name, pTp, params), params);
}

//===----------------------------------------------------------------------===//
// Conversion rules.
//===----------------------------------------------------------------------===//

/// Sparse conversion rule for returns.
class SparseReturnConverter : public OpConversionPattern<ReturnOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReturnOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<ReturnOp>(op, operands);
    return success();
  }
};

/// Sparse conversion rule for dimension accesses.
class SparseTensorToDimSizeConverter
    : public OpConversionPattern<tensor::DimOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::DimOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    auto enc = getSparseTensorEncoding(op.source().getType());
    if (!enc)
      return failure();
    // Permute the dim index.
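    // For example (a sketch), with dimOrdering affine_map<(i,j) -> (j,i)>,
    // i.e. a column-wise format, a query for dimension 0 of the original
    // tensor is redirected to position 1 of the permuted storage.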
    Optional<int64_t> index = op.getConstantIndex();
    if (!index.hasValue())
      return failure();
    int64_t idx = index.getValue();
    if (AffineMap p = enc.getDimOrdering())
      idx = p.getPermutedPosition(idx);
    // Generate the call.
    StringRef name = "sparseDimSize";
    SmallVector<Value, 2> params;
    params.push_back(operands[0]);
    params.push_back(
        rewriter.create<ConstantOp>(op.getLoc(), rewriter.getIndexAttr(idx)));
    rewriter.replaceOpWithNewOp<CallOp>(
        op, resType, getFunc(op, name, resType, params), params);
    return success();
  }
};

/// Sparse conversion rule for the new operator.
class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NewOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    auto enc = getSparseTensorEncoding(resType);
    if (!enc)
      return failure();
    Value perm;
    rewriter.replaceOp(op, genNewCall(rewriter, op, enc, 0, perm, operands[0]));
    return success();
  }
};

/// Sparse conversion rule for the convert operator.
class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ConvertOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    auto encDst = getSparseTensorEncoding(resType);
    auto encSrc = getSparseTensorEncoding(op.source().getType());
    if (encDst && encSrc) {
      // This is a sparse => sparse conversion, which is handled as follows:
      //   t = src->asCOO();        ; src to COO in dst order
      //   dst = newSparseTensor(t)
      // Using the coordinate scheme as an intermediate does not always
      // yield the fastest conversion but avoids the need for a full
      // O(N^2) conversion matrix.
      Value perm;
      Value coo = genNewCall(rewriter, op, encDst, 3, perm, operands[0]);
      rewriter.replaceOp(op, genNewCall(rewriter, op, encDst, 1, perm, coo));
      return success();
    }
    if (!encDst || encSrc) {
      // TODO: sparse => dense
      return failure();
    }
    // This is a dense => sparse conversion, which is handled as follows:
    //   t = newSparseCOO()
    //   for i1 in dim1
    //    ..
    //     for ik in dimk
    //       val = a[i1,..,ik]
    //       if val != 0
    //         t->add(val, [i1,..,ik], [p1,..,pk])
    //   s = newSparseTensor(t)
    // Note that the dense tensor traversal code is actually implemented
    // using MLIR IR to avoid having to expose too much low-level
    // memref traversal details to the runtime support library.
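    // As a rough sketch (assuming a 2-d input %a of type tensor<?x?xf64>),
    // the rewriting below builds IR along the lines of:
    //   %ptr = call @newSparseTensor(...)       ; action 2: empty COO
    //   scf.for %i = %c0 to %d0 step %c1 {
    //     scf.for %j = %c0 to %d1 step %c1 {
    //       %v = tensor.extract %a[%i, %j]
    //       call @addEltF64(%ptr, %v, %ind, %perm)
    //     }
    //   }
    //   %s = call @newSparseTensor(..., %ptr)   ; action 1: from COO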
    Location loc = op->getLoc();
    ShapedType shape = resType.cast<ShapedType>();
    auto memTp =
        MemRefType::get({ShapedType::kDynamicSize}, rewriter.getIndexType());
    Value perm;
    Value ptr = genNewCall(rewriter, op, encDst, 2, perm);
    Value tensor = operands[0];
    Value arg = rewriter.create<ConstantOp>(
        loc, rewriter.getIndexAttr(shape.getRank()));
    Value ind = rewriter.create<memref::AllocaOp>(loc, memTp, ValueRange{arg});
    SmallVector<Value> lo;
    SmallVector<Value> hi;
    SmallVector<Value> st;
    Value zero = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(0));
    Value one = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(1));
    for (unsigned i = 0, rank = shape.getRank(); i < rank; i++) {
      lo.push_back(zero);
      hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, tensor, i));
      st.push_back(one);
    }
    scf::buildLoopNest(rewriter, op.getLoc(), lo, hi, st, {},
                       [&](OpBuilder &builder, Location loc, ValueRange ivs,
                           ValueRange args) -> scf::ValueVector {
                         genAddEltCall(rewriter, op, ptr, tensor, ind, perm,
                                       ivs);
                         return {};
                       });
    rewriter.replaceOp(op, genNewCall(rewriter, op, encDst, 1, perm, ptr));
    return success();
  }
};

/// Sparse conversion rule for pointer accesses.
class SparseTensorToPointersConverter
    : public OpConversionPattern<ToPointersOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToPointersOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type eltType = resType.cast<ShapedType>().getElementType();
    StringRef name;
    if (eltType.isIndex())
      name = "sparsePointers";
    else if (eltType.isInteger(64))
      name = "sparsePointers64";
    else if (eltType.isInteger(32))
      name = "sparsePointers32";
    else if (eltType.isInteger(16))
      name = "sparsePointers16";
    else if (eltType.isInteger(8))
      name = "sparsePointers8";
    else
      return failure();
    rewriter.replaceOpWithNewOp<CallOp>(
        op, resType, getFunc(op, name, resType, operands), operands);
    return success();
  }
};

/// Sparse conversion rule for index accesses.
class SparseTensorToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToIndicesOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type eltType = resType.cast<ShapedType>().getElementType();
    StringRef name;
    if (eltType.isIndex())
      name = "sparseIndices";
    else if (eltType.isInteger(64))
      name = "sparseIndices64";
    else if (eltType.isInteger(32))
      name = "sparseIndices32";
    else if (eltType.isInteger(16))
      name = "sparseIndices16";
    else if (eltType.isInteger(8))
      name = "sparseIndices8";
    else
      return failure();
    rewriter.replaceOpWithNewOp<CallOp>(
        op, resType, getFunc(op, name, resType, operands), operands);
    return success();
  }
};
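
// For illustration (a sketch, assuming 32-bit overhead storage), an access
// such as
//   %p = sparse_tensor.pointers %t, %c1
//          : tensor<?x?xf64, #CSR> to memref<?xi32>
// is rewritten by the patterns above into a library call along the lines of
//   %p = call @sparsePointers32(%ptr, %c1)
//          : (!llvm.ptr<i8>, index) -> memref<?xi32>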

/// Sparse conversion rule for value accesses.
class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToValuesOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    Type eltType = resType.cast<ShapedType>().getElementType();
    StringRef name;
    if (eltType.isF64())
      name = "sparseValuesF64";
    else if (eltType.isF32())
      name = "sparseValuesF32";
    else if (eltType.isInteger(64))
      name = "sparseValuesI64";
    else if (eltType.isInteger(32))
      name = "sparseValuesI32";
    else if (eltType.isInteger(16))
      name = "sparseValuesI16";
    else if (eltType.isInteger(8))
      name = "sparseValuesI8";
    else
      return failure();
    rewriter.replaceOpWithNewOp<CallOp>(
        op, resType, getFunc(op, name, resType, operands), operands);
    return success();
  }
};

/// Sparse conversion rule for tensor reconstruction.
class SparseTensorToTensorConverter : public OpConversionPattern<ToTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  // Simply fold the operation into the pointer to the sparse storage scheme.
  LogicalResult
  matchAndRewrite(ToTensorOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // Check that all arguments of the tensor reconstruction operator are calls
    // into the support library that query exactly the same opaque pointer.
    Value ptr;
    for (Value operand : operands) {
      if (auto call = operand.getDefiningOp<CallOp>()) {
        Value arg = call.getOperand(0);
        if (!arg.getType().isa<LLVM::LLVMPointerType>())
          return failure();
        if (!ptr)
          ptr = arg;
        else if (arg != ptr)
          return failure();
      }
    }
    // If a single opaque pointer is found, perform the folding.
    if (!ptr)
      return failure();
    rewriter.replaceOp(op, ptr);
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with conversion rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
                                                  RewritePatternSet &patterns) {
  patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter,
               SparseTensorNewConverter, SparseTensorConvertConverter,
               SparseTensorToPointersConverter, SparseTensorToIndicesConverter,
               SparseTensorToValuesConverter, SparseTensorToTensorConverter>(
      typeConverter, patterns.getContext());
}
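
// A minimal sketch (hypothetical driver code, not part of this file; the pass
// and type converter names are assumptions) of how a conversion pass applies
// these rules:
//
//   void SparseTensorConversionPass::runOnOperation() {
//     RewritePatternSet patterns(&getContext());
//     SparseTensorTypeConverter converter;
//     ConversionTarget target(getContext());
//     // ... mark sparse tensor ops illegal and library calls legal ...
//     populateSparseTensorConversionPatterns(converter, patterns);
//     if (failed(applyPartialConversion(getOperation(), target,
//                                       std::move(patterns))))
//       signalPassFailure();
//   }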