1 //===- SparseTensorLowering.cpp - Sparse tensor primitives conversion -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Convert sparse tensor primitives to calls into a runtime support library. 10 // Note that this is a current implementation choice to keep the conversion 11 // simple. In principle, these primitives could also be converted to actual 12 // elaborate IR code that implements the primitives on the selected sparse 13 // tensor storage schemes. 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 18 #include "mlir/Dialect/Linalg/Utils/Utils.h" 19 #include "mlir/Dialect/MemRef/IR/MemRef.h" 20 #include "mlir/Dialect/SCF/SCF.h" 21 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 22 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" 23 #include "mlir/Dialect/StandardOps/IR/Ops.h" 24 #include "mlir/Dialect/Tensor/IR/Tensor.h" 25 #include "mlir/Transforms/DialectConversion.h" 26 27 using namespace mlir; 28 using namespace mlir::sparse_tensor; 29 30 namespace { 31 32 //===----------------------------------------------------------------------===// 33 // Helper methods. 34 //===----------------------------------------------------------------------===// 35 36 /// Returns internal type encoding for primary storage. Keep these 37 /// values consistent with the sparse runtime support library. 
38 static unsigned getPrimaryTypeEncoding(Type tp) { 39 if (tp.isF64()) 40 return 1; 41 if (tp.isF32()) 42 return 2; 43 if (tp.isInteger(64)) 44 return 3; 45 if (tp.isInteger(32)) 46 return 4; 47 if (tp.isInteger(16)) 48 return 5; 49 if (tp.isInteger(8)) 50 return 6; 51 return 0; 52 } 53 54 /// Returns internal type encoding for overhead storage. Keep these 55 /// values consistent with the sparse runtime support library. 56 static unsigned getOverheadTypeEncoding(unsigned width) { 57 switch (width) { 58 default: 59 return 1; 60 case 32: 61 return 2; 62 case 16: 63 return 3; 64 case 8: 65 return 4; 66 } 67 } 68 69 /// Returns internal dimension level type encoding. Keep these 70 /// values consistent with the sparse runtime support library. 71 static unsigned 72 getDimLevelTypeEncoding(SparseTensorEncodingAttr::DimLevelType dlt) { 73 switch (dlt) { 74 case SparseTensorEncodingAttr::DimLevelType::Dense: 75 return 0; 76 case SparseTensorEncodingAttr::DimLevelType::Compressed: 77 return 1; 78 case SparseTensorEncodingAttr::DimLevelType::Singleton: 79 return 2; 80 } 81 llvm_unreachable("Unknown SparseTensorEncodingAttr::DimLevelType"); 82 } 83 84 /// Returns integers of given width and values as a constant tensor. 85 /// We cast the static shape into a dynamic shape to ensure that the 86 /// method signature remains uniform accross different tensor dimensions. 87 static Value getTensor(ConversionPatternRewriter &rewriter, unsigned width, 88 Location loc, ArrayRef<APInt> values) { 89 Type etp = rewriter.getIntegerType(width); 90 unsigned sz = values.size(); 91 RankedTensorType tt1 = RankedTensorType::get({sz}, etp); 92 RankedTensorType tt2 = RankedTensorType::get({ShapedType::kDynamicSize}, etp); 93 auto elts = 94 rewriter.create<ConstantOp>(loc, DenseElementsAttr::get(tt1, values)); 95 return rewriter.create<tensor::CastOp>(loc, tt2, elts); 96 } 97 98 /// Returns function reference (first hit also inserts into module). 
static FlatSymbolRefAttr getFunc(Operation *op, StringRef name, Type result,
                                 ValueRange operands) {
  MLIRContext *context = op->getContext();
  auto module = op->getParentOfType<ModuleOp>();
  auto func = module.lookupSymbol<FuncOp>(name);
  if (!func) {
    // First reference to this runtime function: declare it at module
    // scope as a private external so later lookups succeed.
    OpBuilder moduleBuilder(module.getBodyRegion());
    moduleBuilder
        .create<FuncOp>(op->getLoc(), name,
                        FunctionType::get(context, operands.getTypes(), result))
        .setPrivate();
  }
  return SymbolRefAttr::get(context, name);
}

/// Generates a call into the "swiss army knife" method of the sparse runtime
/// support library for materializing sparse tensors into the computation. The
/// method returns the call value and assigns the permutation to 'perm'.
/// The 'action' value is forwarded verbatim to the runtime; callers in this
/// file use 0 (materialize from the source behind 'ptr'), 2 (start an empty
/// coordinate scheme), and 1 (turn the coordinate scheme behind 'ptr' into a
/// sparse tensor). Keep these consistent with the runtime support library.
static Value genNewCall(ConversionPatternRewriter &rewriter, Operation *op,
                        SparseTensorEncodingAttr &enc, uint32_t action,
                        Value &perm, Value ptr = Value()) {
  Location loc = op->getLoc();
  ShapedType resType = op->getResult(0).getType().cast<ShapedType>();
  SmallVector<Value, 8> params;
  // Sparsity annotations in tensor constant form (one 8-bit code per
  // dimension, see getDimLevelTypeEncoding).
  SmallVector<APInt, 4> attrs;
  unsigned sz = enc.getDimLevelType().size();
  for (unsigned i = 0; i < sz; i++)
    attrs.push_back(
        APInt(8, getDimLevelTypeEncoding(enc.getDimLevelType()[i])));
  params.push_back(getTensor(rewriter, 8, loc, attrs));
  // Dimension sizes array of the enveloping *dense* tensor. Useful for either
  // verification of external data, or for construction of internal data.
  // A dynamic size is passed to the runtime as 0.
  auto shape = resType.getShape();
  SmallVector<APInt, 4> sizes;
  for (unsigned i = 0; i < sz; i++) {
    uint64_t s = shape[i] == ShapedType::kDynamicSize ? 0 : shape[i];
    sizes.push_back(APInt(64, s));
  }
  params.push_back(getTensor(rewriter, 64, loc, sizes));
  // Dimension order permutation array. This is the "identity" permutation by
  // default, or otherwise the "reverse" permutation of a given ordering, so
  // that indices can be mapped quickly to the right position.
  SmallVector<APInt, 4> rev(sz);
  if (AffineMap p = enc.getDimOrdering()) {
    for (unsigned i = 0; i < sz; i++)
      rev[p.getDimPosition(i)] = APInt(64, i);
  } else {
    for (unsigned i = 0; i < sz; i++)
      rev[i] = APInt(64, i);
  }
  // Hand the permutation tensor back to the caller as well, since the
  // element-insertion calls need it (see genAddEltCall).
  perm = getTensor(rewriter, 64, loc, rev);
  params.push_back(perm);
  // Secondary and primary types encoding (see the encoding helpers above).
  unsigned secPtr = getOverheadTypeEncoding(enc.getPointerBitWidth());
  unsigned secInd = getOverheadTypeEncoding(enc.getIndexBitWidth());
  unsigned primary = getPrimaryTypeEncoding(resType.getElementType());
  assert(primary);
  params.push_back(
      rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(secPtr)));
  params.push_back(
      rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(secInd)));
  params.push_back(
      rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(primary)));
  // User action and pointer. A null pointer is passed when the caller did
  // not supply an opaque source/COO pointer.
  Type pTp = LLVM::LLVMPointerType::get(IntegerType::get(op->getContext(), 8));
  if (!ptr)
    ptr = rewriter.create<LLVM::NullOp>(loc, pTp);
  params.push_back(
      rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(action)));
  params.push_back(ptr);
  // Generate the call to create new tensor.
  StringRef name = "newSparseTensor";
  auto call =
      rewriter.create<CallOp>(loc, pTp, getFunc(op, name, pTp, params), params);
  return call.getResult(0);
}

/// Generates a call that adds one element to a coordinate scheme.
178 static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op, 179 Value ptr, Value tensor, Value ind, Value perm, 180 ValueRange ivs) { 181 Location loc = op->getLoc(); 182 StringRef name; 183 Type eltType = tensor.getType().cast<ShapedType>().getElementType(); 184 if (eltType.isF64()) 185 name = "addEltF64"; 186 else if (eltType.isF32()) 187 name = "addEltF32"; 188 else if (eltType.isInteger(64)) 189 name = "addEltI64"; 190 else if (eltType.isInteger(32)) 191 name = "addEltI32"; 192 else if (eltType.isInteger(16)) 193 name = "addEltI16"; 194 else if (eltType.isInteger(8)) 195 name = "addEltI8"; 196 else 197 llvm_unreachable("Unknown element type"); 198 Value val = rewriter.create<tensor::ExtractOp>(loc, tensor, ivs); 199 // TODO: add if here? 200 unsigned i = 0; 201 for (auto iv : ivs) { 202 Value idx = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(i++)); 203 rewriter.create<memref::StoreOp>(loc, iv, ind, idx); 204 } 205 SmallVector<Value, 8> params; 206 params.push_back(ptr); 207 params.push_back(val); 208 params.push_back(ind); 209 params.push_back(perm); 210 Type pTp = LLVM::LLVMPointerType::get(IntegerType::get(op->getContext(), 8)); 211 rewriter.create<CallOp>(loc, pTp, getFunc(op, name, pTp, params), params); 212 } 213 214 //===----------------------------------------------------------------------===// 215 // Conversion rules. 216 //===----------------------------------------------------------------------===// 217 218 /// Sparse conversion rule for returns. 219 class SparseReturnConverter : public OpConversionPattern<ReturnOp> { 220 public: 221 using OpConversionPattern::OpConversionPattern; 222 LogicalResult 223 matchAndRewrite(ReturnOp op, ArrayRef<Value> operands, 224 ConversionPatternRewriter &rewriter) const override { 225 rewriter.replaceOpWithNewOp<ReturnOp>(op, operands); 226 return success(); 227 } 228 }; 229 230 /// Sparse conversion rule for dimension accesses. 
231 class SparseTensorToDimSizeConverter 232 : public OpConversionPattern<tensor::DimOp> { 233 public: 234 using OpConversionPattern::OpConversionPattern; 235 LogicalResult 236 matchAndRewrite(tensor::DimOp op, ArrayRef<Value> operands, 237 ConversionPatternRewriter &rewriter) const override { 238 Type resType = op.getType(); 239 auto enc = getSparseTensorEncoding(op.source().getType()); 240 if (!enc) 241 return failure(); 242 // Permute the dim index. 243 Optional<int64_t> index = op.getConstantIndex(); 244 if (!index.hasValue()) 245 return failure(); 246 int64_t idx = index.getValue(); 247 if (AffineMap p = enc.getDimOrdering()) 248 idx = p.getPermutedPosition(idx); 249 // Generate the call. 250 StringRef name = "sparseDimSize"; 251 SmallVector<Value, 2> params; 252 params.push_back(operands[0]); 253 params.push_back( 254 rewriter.create<ConstantOp>(op.getLoc(), rewriter.getIndexAttr(idx))); 255 rewriter.replaceOpWithNewOp<CallOp>( 256 op, resType, getFunc(op, name, resType, params), params); 257 return success(); 258 } 259 }; 260 261 /// Sparse conversion rule for the new operator. 262 class SparseTensorNewConverter : public OpConversionPattern<NewOp> { 263 using OpConversionPattern::OpConversionPattern; 264 LogicalResult 265 matchAndRewrite(NewOp op, ArrayRef<Value> operands, 266 ConversionPatternRewriter &rewriter) const override { 267 Type resType = op.getType(); 268 auto enc = getSparseTensorEncoding(resType); 269 if (!enc) 270 return failure(); 271 Value perm; 272 rewriter.replaceOp(op, genNewCall(rewriter, op, enc, 0, perm, operands[0])); 273 return success(); 274 } 275 }; 276 277 /// Sparse conversion rule for the convert operator. 
class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ConvertOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    auto encDst = getSparseTensorEncoding(resType);
    auto encSrc = getSparseTensorEncoding(op.source().getType());
    // TODO: implement sparse => sparse
    //       and sparse => dense
    if (!encDst || encSrc)
      return failure();
    // This is a dense => sparse conversion, that is handled as follows:
    //   t = newSparseCOO()
    //   for i1 in dim1
    //    ..
    //     for ik in dimk
    //       val = a[i1,..,ik]
    //       if val != 0
    //         t->add(val, [i1,..,ik], [p1,..,pk])
    //   s = newSparseTensor(t)
    // Note that the dense tensor traversal code is actually implemented
    // using MLIR IR to avoid having to expose too much low-level
    // memref traversal details to the runtime support library.
    // NOTE(review): the emitted loop body currently inserts every element,
    // including zeros; the `if val != 0` in the sketch above is not yet
    // generated (see the TODO inside genAddEltCall).
    Location loc = op->getLoc();
    ShapedType shape = resType.cast<ShapedType>();
    auto memTp =
        MemRefType::get({ShapedType::kDynamicSize}, rewriter.getIndexType());
    // Action 2 creates the empty intermediate coordinate scheme; 'perm'
    // receives the dimension permutation tensor used by each insertion.
    Value perm;
    Value ptr = genNewCall(rewriter, op, encDst, 2, perm);
    Value tensor = operands[0];
    // Stack buffer of 'rank' indices, reused by every insertion call.
    Value arg = rewriter.create<ConstantOp>(
        loc, rewriter.getIndexAttr(shape.getRank()));
    Value ind = rewriter.create<memref::AllocaOp>(loc, memTp, ValueRange{arg});
    // Build the loop bounds [0, dim_i) with unit stride for each dimension.
    SmallVector<Value> lo;
    SmallVector<Value> hi;
    SmallVector<Value> st;
    Value zero = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(0));
    Value one = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(1));
    for (unsigned i = 0, rank = shape.getRank(); i < rank; i++) {
      lo.push_back(zero);
      hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, tensor, i));
      st.push_back(one);
    }
    // Emit the full loop nest; each iteration adds one element to the
    // coordinate scheme behind 'ptr'.
    scf::buildLoopNest(rewriter, op.getLoc(), lo, hi, st, {},
                       [&](OpBuilder &builder, Location loc, ValueRange ivs,
                           ValueRange args) -> scf::ValueVector {
                         genAddEltCall(rewriter, op, ptr, tensor, ind, perm,
                                       ivs);
                         return {};
                       });
    // Action 1 turns the filled coordinate scheme into the final sparse
    // tensor, which replaces the convert op.
    rewriter.replaceOp(op, genNewCall(rewriter, op, encDst, 1, perm, ptr));
    return success();
  }
};

/// Sparse conversion rule for pointer accesses.
335 class SparseTensorToPointersConverter 336 : public OpConversionPattern<ToPointersOp> { 337 public: 338 using OpConversionPattern::OpConversionPattern; 339 LogicalResult 340 matchAndRewrite(ToPointersOp op, ArrayRef<Value> operands, 341 ConversionPatternRewriter &rewriter) const override { 342 Type resType = op.getType(); 343 Type eltType = resType.cast<ShapedType>().getElementType(); 344 StringRef name; 345 if (eltType.isIndex()) 346 name = "sparsePointers"; 347 else if (eltType.isInteger(64)) 348 name = "sparsePointers64"; 349 else if (eltType.isInteger(32)) 350 name = "sparsePointers32"; 351 else if (eltType.isInteger(16)) 352 name = "sparsePointers16"; 353 else if (eltType.isInteger(8)) 354 name = "sparsePointers8"; 355 else 356 return failure(); 357 rewriter.replaceOpWithNewOp<CallOp>( 358 op, resType, getFunc(op, name, resType, operands), operands); 359 return success(); 360 } 361 }; 362 363 /// Sparse conversion rule for index accesses. 364 class SparseTensorToIndicesConverter : public OpConversionPattern<ToIndicesOp> { 365 public: 366 using OpConversionPattern::OpConversionPattern; 367 LogicalResult 368 matchAndRewrite(ToIndicesOp op, ArrayRef<Value> operands, 369 ConversionPatternRewriter &rewriter) const override { 370 Type resType = op.getType(); 371 Type eltType = resType.cast<ShapedType>().getElementType(); 372 StringRef name; 373 if (eltType.isIndex()) 374 name = "sparseIndices"; 375 else if (eltType.isInteger(64)) 376 name = "sparseIndices64"; 377 else if (eltType.isInteger(32)) 378 name = "sparseIndices32"; 379 else if (eltType.isInteger(16)) 380 name = "sparseIndices16"; 381 else if (eltType.isInteger(8)) 382 name = "sparseIndices8"; 383 else 384 return failure(); 385 rewriter.replaceOpWithNewOp<CallOp>( 386 op, resType, getFunc(op, name, resType, operands), operands); 387 return success(); 388 } 389 }; 390 391 /// Sparse conversion rule for value accesses. 
392 class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> { 393 public: 394 using OpConversionPattern::OpConversionPattern; 395 LogicalResult 396 matchAndRewrite(ToValuesOp op, ArrayRef<Value> operands, 397 ConversionPatternRewriter &rewriter) const override { 398 Type resType = op.getType(); 399 Type eltType = resType.cast<ShapedType>().getElementType(); 400 StringRef name; 401 if (eltType.isF64()) 402 name = "sparseValuesF64"; 403 else if (eltType.isF32()) 404 name = "sparseValuesF32"; 405 else if (eltType.isInteger(64)) 406 name = "sparseValuesI64"; 407 else if (eltType.isInteger(32)) 408 name = "sparseValuesI32"; 409 else if (eltType.isInteger(16)) 410 name = "sparseValuesI16"; 411 else if (eltType.isInteger(8)) 412 name = "sparseValuesI8"; 413 else 414 return failure(); 415 rewriter.replaceOpWithNewOp<CallOp>( 416 op, resType, getFunc(op, name, resType, operands), operands); 417 return success(); 418 } 419 }; 420 421 /// Sparse conversion rule for tensor reconstruction. 422 class SparseTensorToTensorConverter : public OpConversionPattern<ToTensorOp> { 423 public: 424 using OpConversionPattern::OpConversionPattern; 425 LogicalResult 426 // Simply fold the operator into the pointer to the sparse storage scheme. 427 matchAndRewrite(ToTensorOp op, ArrayRef<Value> operands, 428 ConversionPatternRewriter &rewriter) const override { 429 // Check that all arguments of the tensor reconstruction operators are calls 430 // into the support library that query exactly the same opaque pointer. 431 Value ptr; 432 for (Value op : operands) { 433 if (auto call = op.getDefiningOp<CallOp>()) { 434 Value arg = call.getOperand(0); 435 if (!arg.getType().isa<LLVM::LLVMPointerType>()) 436 return failure(); 437 if (!ptr) 438 ptr = arg; 439 else if (arg != ptr) 440 return failure(); 441 } 442 } 443 // If a single opaque pointer is found, perform the folding. 
444 if (!ptr) 445 return failure(); 446 rewriter.replaceOp(op, ptr); 447 return success(); 448 } 449 }; 450 451 } // namespace 452 453 //===----------------------------------------------------------------------===// 454 // Public method for populating conversion rules. 455 //===----------------------------------------------------------------------===// 456 457 /// Populates the given patterns list with conversion rules required for 458 /// the sparsification of linear algebra operations. 459 void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter, 460 RewritePatternSet &patterns) { 461 patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter, 462 SparseTensorNewConverter, SparseTensorConvertConverter, 463 SparseTensorToPointersConverter, SparseTensorToIndicesConverter, 464 SparseTensorToValuesConverter, SparseTensorToTensorConverter>( 465 typeConverter, patterns.getContext()); 466 } 467