1 //===- SparseTensorLowering.cpp - Sparse tensor primitives conversion -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Convert sparse tensor primitives to calls into a runtime support library. 10 // Note that this is a current implementation choice to keep the conversion 11 // simple. In principle, these primitives could also be converted to actual 12 // elaborate IR code that implements the primitives on the selected sparse 13 // tensor storage schemes. 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "mlir/Dialect/LLVMIR/LLVMTypes.h" 18 #include "mlir/Dialect/MemRef/IR/MemRef.h" 19 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 20 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" 21 #include "mlir/Dialect/StandardOps/IR/Ops.h" 22 #include "mlir/Dialect/Tensor/IR/Tensor.h" 23 #include "mlir/Transforms/DialectConversion.h" 24 25 using namespace mlir; 26 using namespace mlir::sparse_tensor; 27 28 namespace { 29 30 //===----------------------------------------------------------------------===// 31 // Helper methods. 32 //===----------------------------------------------------------------------===// 33 34 /// Returns internal type encoding for primary storage. Keep these 35 /// values consistent with the sparse runtime support library. 36 static unsigned getPrimaryTypeEncoding(Type tp) { 37 if (tp.isF64()) 38 return 1; 39 if (tp.isF32()) 40 return 2; 41 if (tp.isInteger(64)) 42 return 3; 43 if (tp.isInteger(32)) 44 return 4; 45 if (tp.isInteger(16)) 46 return 5; 47 if (tp.isInteger(8)) 48 return 6; 49 return 0; 50 } 51 52 /// Returns internal type encoding for overhead storage. Keep these 53 /// values consistent with the sparse runtime support library. 54 static unsigned getOverheadTypeEncoding(unsigned width) { 55 switch (width) { 56 default: 57 return 1; 58 case 32: 59 return 2; 60 case 16: 61 return 3; 62 case 8: 63 return 4; 64 } 65 } 66 67 /// Returns internal dimension level type encoding. Keep these 68 /// values consistent with the sparse runtime support library. 69 static unsigned 70 getDimLevelTypeEncoding(SparseTensorEncodingAttr::DimLevelType dlt) { 71 switch (dlt) { 72 case SparseTensorEncodingAttr::DimLevelType::Dense: 73 return 0; 74 case SparseTensorEncodingAttr::DimLevelType::Compressed: 75 return 1; 76 case SparseTensorEncodingAttr::DimLevelType::Singleton: 77 return 2; 78 } 79 llvm_unreachable("Unknown SparseTensorEncodingAttr::DimLevelType"); 80 } 81 82 /// Returns integers of given width and values as a constant tensor. 83 /// We cast the static shape into a dynamic shape to ensure that the 84 /// method signature remains uniform accross different tensor dimensions. 85 static Value getTensor(ConversionPatternRewriter &rewriter, unsigned width, 86 Location loc, ArrayRef<APInt> values) { 87 Type etp = rewriter.getIntegerType(width); 88 unsigned sz = values.size(); 89 RankedTensorType tt1 = RankedTensorType::get({sz}, etp); 90 RankedTensorType tt2 = RankedTensorType::get({ShapedType::kDynamicSize}, etp); 91 auto elts = 92 rewriter.create<ConstantOp>(loc, DenseElementsAttr::get(tt1, values)); 93 return rewriter.create<tensor::CastOp>(loc, tt2, elts); 94 } 95 96 /// Returns function reference (first hit also inserts into module). 97 static FlatSymbolRefAttr getFunc(Operation *op, StringRef name, Type result, 98 ValueRange operands) { 99 MLIRContext *context = op->getContext(); 100 auto module = op->getParentOfType<ModuleOp>(); 101 auto func = module.lookupSymbol<FuncOp>(name); 102 if (!func) { 103 OpBuilder moduleBuilder(module.getBodyRegion()); 104 moduleBuilder 105 .create<FuncOp>(op->getLoc(), name, 106 FunctionType::get(context, operands.getTypes(), result)) 107 .setPrivate(); 108 } 109 return SymbolRefAttr::get(context, name); 110 } 111 112 /// Generates a call into the "swiss army knife" method of the sparse runtime 113 /// support library for materializing sparse tensors into the computation. 114 static void genNewCall(ConversionPatternRewriter &rewriter, Operation *op, 115 SparseTensorEncodingAttr &enc, uint32_t action, 116 Value ptr) { 117 Location loc = op->getLoc(); 118 ShapedType resType = op->getResult(0).getType().cast<ShapedType>(); 119 SmallVector<Value, 8> params; 120 // Sparsity annotations in tensor constant form. 121 SmallVector<APInt, 4> attrs; 122 unsigned sz = enc.getDimLevelType().size(); 123 for (unsigned i = 0; i < sz; i++) 124 attrs.push_back( 125 APInt(8, getDimLevelTypeEncoding(enc.getDimLevelType()[i]))); 126 params.push_back(getTensor(rewriter, 8, loc, attrs)); 127 // Dimension sizes array of the enveloping *dense* tensor. Useful for either 128 // verification of external data, or for construction of internal data. 129 auto shape = resType.getShape(); 130 SmallVector<APInt, 4> sizes; 131 for (unsigned i = 0; i < sz; i++) { 132 uint64_t s = shape[i] == ShapedType::kDynamicSize ? 0 : shape[i]; 133 sizes.push_back(APInt(64, s)); 134 } 135 params.push_back(getTensor(rewriter, 64, loc, sizes)); 136 // Dimension order permutation array. This is the "identity" permutation by 137 // default, or otherwise the "reverse" permutation of a given ordering, so 138 // that indices can be mapped quickly to the right position. 139 SmallVector<APInt, 4> perm(sz); 140 AffineMap p = enc.getDimOrdering(); 141 if (p) { 142 assert(p.isPermutation() && p.getNumResults() == sz); 143 for (unsigned i = 0; i < sz; i++) 144 perm[p.getDimPosition(i)] = APInt(64, i); 145 } else { 146 for (unsigned i = 0; i < sz; i++) 147 perm[i] = APInt(64, i); 148 } 149 params.push_back(getTensor(rewriter, 64, loc, perm)); 150 // Secondary and primary types encoding. 151 unsigned secPtr = getOverheadTypeEncoding(enc.getPointerBitWidth()); 152 unsigned secInd = getOverheadTypeEncoding(enc.getIndexBitWidth()); 153 unsigned primary = getPrimaryTypeEncoding(resType.getElementType()); 154 assert(primary); 155 params.push_back( 156 rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(secPtr))); 157 params.push_back( 158 rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(secInd))); 159 params.push_back( 160 rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(primary))); 161 // User action and pointer. 162 params.push_back( 163 rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(action))); 164 params.push_back(ptr); 165 // Generate the call to create new tensor. 166 Type ptrType = 167 LLVM::LLVMPointerType::get(IntegerType::get(op->getContext(), 8)); 168 StringRef name = "newSparseTensor"; 169 rewriter.replaceOpWithNewOp<CallOp>( 170 op, ptrType, getFunc(op, name, ptrType, params), params); 171 } 172 173 /// Generates a call that exposes the data pointer as a void pointer. 174 // TODO: probing the data pointer directly is a bit raw; we should replace 175 // this with proper memref util calls once they become available. 176 static bool genPtrCall(ConversionPatternRewriter &rewriter, Operation *op, 177 Value val, Value &ptr) { 178 Location loc = op->getLoc(); 179 ShapedType sType = op->getResult(0).getType().cast<ShapedType>(); 180 Type eltType = sType.getElementType(); 181 // Specialize name for the data type. Even though the final buffferized 182 // version only operates on pointers, different names are required to 183 // ensure type correctness for all intermediate states. 184 StringRef name; 185 if (eltType.isF64()) 186 name = "getPtrF64"; 187 else if (eltType.isF32()) 188 name = "getPtrF32"; 189 else if (eltType.isInteger(64)) 190 name = "getPtrI64"; 191 else if (eltType.isInteger(32)) 192 name = "getPtrI32"; 193 else if (eltType.isInteger(16)) 194 name = "getPtrI16"; 195 else if (eltType.isInteger(8)) 196 name = "getPtrI8"; 197 else 198 return false; 199 auto memRefTp = MemRefType::get(sType.getShape(), eltType); 200 auto unrankedTp = UnrankedMemRefType::get(eltType, 0); 201 Value c = rewriter.create<memref::BufferCastOp>(loc, memRefTp, val); 202 Value d = rewriter.create<memref::CastOp>(loc, unrankedTp, c); 203 Type ptrType = 204 LLVM::LLVMPointerType::get(IntegerType::get(op->getContext(), 8)); 205 auto call = 206 rewriter.create<CallOp>(loc, ptrType, getFunc(op, name, ptrType, d), d); 207 ptr = call.getResult(0); 208 return true; 209 } 210 211 //===----------------------------------------------------------------------===// 212 // Conversion rules. 213 //===----------------------------------------------------------------------===// 214 215 /// Sparse conversion rule for returns. 216 class SparseReturnConverter : public OpConversionPattern<ReturnOp> { 217 public: 218 using OpConversionPattern::OpConversionPattern; 219 LogicalResult 220 matchAndRewrite(ReturnOp op, ArrayRef<Value> operands, 221 ConversionPatternRewriter &rewriter) const override { 222 rewriter.replaceOpWithNewOp<ReturnOp>(op, operands); 223 return success(); 224 } 225 }; 226 227 /// Sparse conversion rule for dimension accesses. 228 class SparseTensorToDimSizeConverter 229 : public OpConversionPattern<tensor::DimOp> { 230 public: 231 using OpConversionPattern::OpConversionPattern; 232 LogicalResult 233 matchAndRewrite(tensor::DimOp op, ArrayRef<Value> operands, 234 ConversionPatternRewriter &rewriter) const override { 235 Type resType = op.getType(); 236 auto enc = getSparseTensorEncoding(op.source().getType()); 237 if (!enc) 238 return failure(); 239 // Permute the dim index. 240 Optional<int64_t> index = op.getConstantIndex(); 241 if (!index.hasValue()) 242 return failure(); 243 int64_t idx = index.getValue(); 244 AffineMap p = enc.getDimOrdering(); 245 if (p) { 246 assert(p.isPermutation()); 247 for (unsigned i = 0, sz = p.getNumResults(); i < sz; i++) { 248 if (p.getDimPosition(i) == idx) { 249 idx = i; 250 break; 251 } 252 } 253 } 254 // Generate the call. 255 StringRef name = "sparseDimSize"; 256 SmallVector<Value, 2> params; 257 params.push_back(operands[0]); 258 params.push_back( 259 rewriter.create<ConstantOp>(op.getLoc(), rewriter.getIndexAttr(idx))); 260 rewriter.replaceOpWithNewOp<CallOp>( 261 op, resType, getFunc(op, name, resType, params), params); 262 return success(); 263 } 264 }; 265 266 /// Sparse conversion rule for the new operator. 267 class SparseTensorNewConverter : public OpConversionPattern<NewOp> { 268 using OpConversionPattern::OpConversionPattern; 269 LogicalResult 270 matchAndRewrite(NewOp op, ArrayRef<Value> operands, 271 ConversionPatternRewriter &rewriter) const override { 272 Type resType = op.getType(); 273 auto enc = getSparseTensorEncoding(resType); 274 if (!enc) 275 return failure(); 276 genNewCall(rewriter, op, enc, 0, operands[0]); 277 return success(); 278 } 279 }; 280 281 /// Sparse conversion rule for the convert operator. 282 class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> { 283 using OpConversionPattern::OpConversionPattern; 284 LogicalResult 285 matchAndRewrite(ConvertOp op, ArrayRef<Value> operands, 286 ConversionPatternRewriter &rewriter) const override { 287 Type resType = op.getType(); 288 auto encDst = getSparseTensorEncoding(resType); 289 auto encSrc = getSparseTensorEncoding(op.source().getType()); 290 // TODO: implement sparse => sparse 291 // and sparse => dense 292 if (!encDst || encSrc) 293 return failure(); 294 // This is a dense => sparse conversion. 295 Value ptr; 296 if (!genPtrCall(rewriter, op, operands[0], ptr)) 297 return failure(); 298 genNewCall(rewriter, op, encDst, 1, ptr); 299 return success(); 300 } 301 }; 302 303 /// Sparse conversion rule for pointer accesses. 304 class SparseTensorToPointersConverter 305 : public OpConversionPattern<ToPointersOp> { 306 public: 307 using OpConversionPattern::OpConversionPattern; 308 LogicalResult 309 matchAndRewrite(ToPointersOp op, ArrayRef<Value> operands, 310 ConversionPatternRewriter &rewriter) const override { 311 Type resType = op.getType(); 312 Type eltType = resType.cast<ShapedType>().getElementType(); 313 StringRef name; 314 if (eltType.isIndex()) 315 name = "sparsePointers"; 316 else if (eltType.isInteger(64)) 317 name = "sparsePointers64"; 318 else if (eltType.isInteger(32)) 319 name = "sparsePointers32"; 320 else if (eltType.isInteger(16)) 321 name = "sparsePointers16"; 322 else if (eltType.isInteger(8)) 323 name = "sparsePointers8"; 324 else 325 return failure(); 326 rewriter.replaceOpWithNewOp<CallOp>( 327 op, resType, getFunc(op, name, resType, operands), operands); 328 return success(); 329 } 330 }; 331 332 /// Sparse conversion rule for index accesses. 333 class SparseTensorToIndicesConverter : public OpConversionPattern<ToIndicesOp> { 334 public: 335 using OpConversionPattern::OpConversionPattern; 336 LogicalResult 337 matchAndRewrite(ToIndicesOp op, ArrayRef<Value> operands, 338 ConversionPatternRewriter &rewriter) const override { 339 Type resType = op.getType(); 340 Type eltType = resType.cast<ShapedType>().getElementType(); 341 StringRef name; 342 if (eltType.isIndex()) 343 name = "sparseIndices"; 344 else if (eltType.isInteger(64)) 345 name = "sparseIndices64"; 346 else if (eltType.isInteger(32)) 347 name = "sparseIndices32"; 348 else if (eltType.isInteger(16)) 349 name = "sparseIndices16"; 350 else if (eltType.isInteger(8)) 351 name = "sparseIndices8"; 352 else 353 return failure(); 354 rewriter.replaceOpWithNewOp<CallOp>( 355 op, resType, getFunc(op, name, resType, operands), operands); 356 return success(); 357 } 358 }; 359 360 /// Sparse conversion rule for value accesses. 361 class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> { 362 public: 363 using OpConversionPattern::OpConversionPattern; 364 LogicalResult 365 matchAndRewrite(ToValuesOp op, ArrayRef<Value> operands, 366 ConversionPatternRewriter &rewriter) const override { 367 Type resType = op.getType(); 368 Type eltType = resType.cast<ShapedType>().getElementType(); 369 StringRef name; 370 if (eltType.isF64()) 371 name = "sparseValuesF64"; 372 else if (eltType.isF32()) 373 name = "sparseValuesF32"; 374 else if (eltType.isInteger(64)) 375 name = "sparseValuesI64"; 376 else if (eltType.isInteger(32)) 377 name = "sparseValuesI32"; 378 else if (eltType.isInteger(16)) 379 name = "sparseValuesI16"; 380 else if (eltType.isInteger(8)) 381 name = "sparseValuesI8"; 382 else 383 return failure(); 384 rewriter.replaceOpWithNewOp<CallOp>( 385 op, resType, getFunc(op, name, resType, operands), operands); 386 return success(); 387 } 388 }; 389 390 /// Sparse conversion rule for tensor reconstruction. 391 class SparseTensorToTensorConverter : public OpConversionPattern<ToTensorOp> { 392 public: 393 using OpConversionPattern::OpConversionPattern; 394 LogicalResult 395 // Simply fold the operator into the pointer to the sparse storage scheme. 396 matchAndRewrite(ToTensorOp op, ArrayRef<Value> operands, 397 ConversionPatternRewriter &rewriter) const override { 398 // Check that all arguments of the tensor reconstruction operators are calls 399 // into the support library that query exactly the same opaque pointer. 400 Value ptr; 401 for (Value op : operands) { 402 if (auto call = op.getDefiningOp<CallOp>()) { 403 Value arg = call.getOperand(0); 404 if (!arg.getType().isa<LLVM::LLVMPointerType>()) 405 return failure(); 406 if (!ptr) 407 ptr = arg; 408 else if (arg != ptr) 409 return failure(); 410 } 411 } 412 // If a single opaque pointer is found, perform the folding. 413 if (!ptr) 414 return failure(); 415 rewriter.replaceOp(op, ptr); 416 return success(); 417 } 418 }; 419 420 } // namespace 421 422 //===----------------------------------------------------------------------===// 423 // Public method for populating conversion rules. 424 //===----------------------------------------------------------------------===// 425 426 /// Populates the given patterns list with conversion rules required for 427 /// the sparsification of linear algebra operations. 428 void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter, 429 RewritePatternSet &patterns) { 430 patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter, 431 SparseTensorNewConverter, SparseTensorConvertConverter, 432 SparseTensorToPointersConverter, SparseTensorToIndicesConverter, 433 SparseTensorToValuesConverter, SparseTensorToTensorConverter>( 434 typeConverter, patterns.getContext()); 435 } 436