1 //===- SparseTensorLowering.cpp - Sparse tensor primitives conversion -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Convert sparse tensor primitives to calls into a runtime support library. 10 // Note that this is a current implementation choice to keep the conversion 11 // simple. In principle, these primitives could also be converted to actual 12 // elaborate IR code that implements the primitives on the selected sparse 13 // tensor storage schemes. 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "mlir/Dialect/LLVMIR/LLVMTypes.h" 18 #include "mlir/Dialect/MemRef/IR/MemRef.h" 19 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 20 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" 21 #include "mlir/Dialect/StandardOps/IR/Ops.h" 22 #include "mlir/Dialect/Tensor/IR/Tensor.h" 23 #include "mlir/Transforms/DialectConversion.h" 24 25 using namespace mlir; 26 using namespace mlir::sparse_tensor; 27 28 namespace { 29 30 //===----------------------------------------------------------------------===// 31 // Helper methods. 32 //===----------------------------------------------------------------------===// 33 34 /// Returns internal type encoding for primary storage. Keep these 35 /// values consistent with the sparse runtime support library. 36 static unsigned getPrimaryTypeEncoding(Type tp) { 37 if (tp.isF64()) 38 return 1; 39 if (tp.isF32()) 40 return 2; 41 if (tp.isInteger(64)) 42 return 3; 43 if (tp.isInteger(32)) 44 return 4; 45 if (tp.isInteger(16)) 46 return 5; 47 if (tp.isInteger(8)) 48 return 6; 49 return 0; 50 } 51 52 /// Returns internal type encoding for overhead storage. Keep these 53 /// values consistent with the sparse runtime support library. 54 static unsigned getOverheadTypeEncoding(unsigned width) { 55 switch (width) { 56 default: 57 return 1; 58 case 32: 59 return 2; 60 case 16: 61 return 3; 62 case 8: 63 return 4; 64 } 65 } 66 67 /// Returns internal dimension level type encoding. Keep these 68 /// values consistent with the sparse runtime support library. 69 static unsigned 70 getDimLevelTypeEncoding(SparseTensorEncodingAttr::DimLevelType dlt) { 71 switch (dlt) { 72 case SparseTensorEncodingAttr::DimLevelType::Dense: 73 return 0; 74 case SparseTensorEncodingAttr::DimLevelType::Compressed: 75 return 1; 76 case SparseTensorEncodingAttr::DimLevelType::Singleton: 77 return 2; 78 } 79 llvm_unreachable("Unknown SparseTensorEncodingAttr::DimLevelType"); 80 } 81 82 /// Returns integers of given width and values as a constant tensor. 83 /// We cast the static shape into a dynamic shape to ensure that the 84 /// method signature remains uniform accross different tensor dimensions. 85 static Value getTensor(ConversionPatternRewriter &rewriter, unsigned width, 86 Location loc, ArrayRef<APInt> values) { 87 Type etp = rewriter.getIntegerType(width); 88 unsigned sz = values.size(); 89 RankedTensorType tt1 = RankedTensorType::get({sz}, etp); 90 RankedTensorType tt2 = RankedTensorType::get({ShapedType::kDynamicSize}, etp); 91 auto elts = 92 rewriter.create<ConstantOp>(loc, DenseElementsAttr::get(tt1, values)); 93 return rewriter.create<tensor::CastOp>(loc, tt2, elts); 94 } 95 96 /// Returns function reference (first hit also inserts into module). 97 static FlatSymbolRefAttr getFunc(Operation *op, StringRef name, Type result, 98 ValueRange operands) { 99 MLIRContext *context = op->getContext(); 100 auto module = op->getParentOfType<ModuleOp>(); 101 auto func = module.lookupSymbol<FuncOp>(name); 102 if (!func) { 103 OpBuilder moduleBuilder(module.getBodyRegion()); 104 moduleBuilder 105 .create<FuncOp>(op->getLoc(), name, 106 FunctionType::get(context, operands.getTypes(), result)) 107 .setPrivate(); 108 } 109 return SymbolRefAttr::get(context, name); 110 } 111 112 /// Generates a call into the "swiss army knife" method of the sparse runtime 113 /// support library for materializing sparse tensors into the computation. 114 static void genNewCall(ConversionPatternRewriter &rewriter, Operation *op, 115 SparseTensorEncodingAttr &enc, uint32_t action, 116 Value ptr) { 117 Location loc = op->getLoc(); 118 ShapedType resType = op->getResult(0).getType().cast<ShapedType>(); 119 SmallVector<Value, 8> params; 120 // Sparsity annotations in tensor constant form. 121 SmallVector<APInt, 4> attrs; 122 unsigned sz = enc.getDimLevelType().size(); 123 for (unsigned i = 0; i < sz; i++) 124 attrs.push_back( 125 APInt(8, getDimLevelTypeEncoding(enc.getDimLevelType()[i]))); 126 params.push_back(getTensor(rewriter, 8, loc, attrs)); 127 // Dimension sizes array of the enveloping *dense* tensor. Useful for either 128 // verification of external data, or for construction of internal data. 129 auto shape = resType.getShape(); 130 SmallVector<APInt, 4> sizes; 131 for (unsigned i = 0; i < sz; i++) { 132 uint64_t s = shape[i] == ShapedType::kDynamicSize ? 0 : shape[i]; 133 sizes.push_back(APInt(64, s)); 134 } 135 params.push_back(getTensor(rewriter, 64, loc, sizes)); 136 // Dimension order permutation array. This is the "identity" permutation by 137 // default, or otherwise the "reverse" permutation of a given ordering, so 138 // that indices can be mapped quickly to the right position. 139 SmallVector<APInt, 4> perm(sz); 140 AffineMap p = enc.getDimOrdering(); 141 if (p) { 142 assert(p.isPermutation() && p.getNumResults() == sz); 143 for (unsigned i = 0; i < sz; i++) 144 perm[p.getDimPosition(i)] = APInt(64, i); 145 } else { 146 for (unsigned i = 0; i < sz; i++) 147 perm[i] = APInt(64, i); 148 } 149 params.push_back(getTensor(rewriter, 64, loc, perm)); 150 // Secondary and primary types encoding. 151 unsigned secPtr = getOverheadTypeEncoding(enc.getPointerBitWidth()); 152 unsigned secInd = getOverheadTypeEncoding(enc.getIndexBitWidth()); 153 unsigned primary = getPrimaryTypeEncoding(resType.getElementType()); 154 assert(primary); 155 params.push_back( 156 rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(secPtr))); 157 params.push_back( 158 rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(secInd))); 159 params.push_back( 160 rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(primary))); 161 // User action and pointer. 162 params.push_back( 163 rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(action))); 164 params.push_back(ptr); 165 // Generate the call to create new tensor. 166 Type ptrType = 167 LLVM::LLVMPointerType::get(IntegerType::get(op->getContext(), 8)); 168 StringRef name = "newSparseTensor"; 169 rewriter.replaceOpWithNewOp<CallOp>( 170 op, ptrType, getFunc(op, name, ptrType, params), params); 171 } 172 173 /// Generates a call that exposes the data pointer as a void pointer. 174 // TODO: probing the data pointer directly is a bit raw; we should replace 175 // this with proper memref util calls once they become available. 176 static bool genPtrCall(ConversionPatternRewriter &rewriter, Operation *op, 177 Value val, Value &ptr) { 178 Location loc = op->getLoc(); 179 ShapedType sType = op->getResult(0).getType().cast<ShapedType>(); 180 Type eltType = sType.getElementType(); 181 // Specialize name for the data type. Even though the final buffferized 182 // version only operates on pointers, different names are required to 183 // ensure type correctness for all intermediate states. 184 StringRef name; 185 if (eltType.isF64()) 186 name = "getPtrF64"; 187 else if (eltType.isF32()) 188 name = "getPtrF32"; 189 else if (eltType.isInteger(64)) 190 name = "getPtrI64"; 191 else if (eltType.isInteger(32)) 192 name = "getPtrI32"; 193 else if (eltType.isInteger(16)) 194 name = "getPtrI16"; 195 else if (eltType.isInteger(8)) 196 name = "getPtrI8"; 197 else 198 return false; 199 auto memRefTp = MemRefType::get(sType.getShape(), eltType); 200 auto unrankedTp = UnrankedMemRefType::get(eltType, 0); 201 Value c = rewriter.create<memref::BufferCastOp>(loc, memRefTp, val); 202 Value d = rewriter.create<memref::CastOp>(loc, unrankedTp, c); 203 Type ptrType = 204 LLVM::LLVMPointerType::get(IntegerType::get(op->getContext(), 8)); 205 auto call = 206 rewriter.create<CallOp>(loc, ptrType, getFunc(op, name, ptrType, d), d); 207 ptr = call.getResult(0); 208 return true; 209 } 210 211 //===----------------------------------------------------------------------===// 212 // Conversion rules. 213 //===----------------------------------------------------------------------===// 214 215 /// Sparse conversion rule for returns. 216 class SparseReturnConverter : public OpConversionPattern<ReturnOp> { 217 public: 218 using OpConversionPattern::OpConversionPattern; 219 LogicalResult 220 matchAndRewrite(ReturnOp op, ArrayRef<Value> operands, 221 ConversionPatternRewriter &rewriter) const override { 222 rewriter.replaceOpWithNewOp<ReturnOp>(op, operands); 223 return success(); 224 } 225 }; 226 227 /// Sparse conversion rule for dimension accesses. 228 class SparseTensorToDimSizeConverter 229 : public OpConversionPattern<tensor::DimOp> { 230 public: 231 using OpConversionPattern::OpConversionPattern; 232 LogicalResult 233 matchAndRewrite(tensor::DimOp op, ArrayRef<Value> operands, 234 ConversionPatternRewriter &rewriter) const override { 235 if (!operands[0].getType().isa<LLVM::LLVMPointerType>()) 236 return failure(); 237 Type resType = op.getType(); 238 StringRef name = "sparseDimSize"; 239 rewriter.replaceOpWithNewOp<CallOp>( 240 op, resType, getFunc(op, name, resType, operands), operands); 241 return success(); 242 } 243 }; 244 245 /// Sparse conversion rule for the new operator. 246 class SparseTensorNewConverter : public OpConversionPattern<NewOp> { 247 using OpConversionPattern::OpConversionPattern; 248 LogicalResult 249 matchAndRewrite(NewOp op, ArrayRef<Value> operands, 250 ConversionPatternRewriter &rewriter) const override { 251 Type resType = op.getType(); 252 auto enc = getSparseTensorEncoding(resType); 253 if (!enc) 254 return failure(); 255 genNewCall(rewriter, op, enc, 0, operands[0]); 256 return success(); 257 } 258 }; 259 260 /// Sparse conversion rule for the convert operator. 261 class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> { 262 using OpConversionPattern::OpConversionPattern; 263 LogicalResult 264 matchAndRewrite(ConvertOp op, ArrayRef<Value> operands, 265 ConversionPatternRewriter &rewriter) const override { 266 Type resType = op.getType(); 267 auto encDst = getSparseTensorEncoding(resType); 268 auto encSrc = getSparseTensorEncoding(op.source().getType()); 269 // TODO: implement sparse => sparse 270 // and sparse => dense 271 if (!encDst || encSrc) 272 return failure(); 273 // This is a dense => sparse conversion. 274 Value ptr; 275 if (!genPtrCall(rewriter, op, operands[0], ptr)) 276 return failure(); 277 genNewCall(rewriter, op, encDst, 1, ptr); 278 return success(); 279 } 280 }; 281 282 /// Sparse conversion rule for pointer accesses. 283 class SparseTensorToPointersConverter 284 : public OpConversionPattern<ToPointersOp> { 285 public: 286 using OpConversionPattern::OpConversionPattern; 287 LogicalResult 288 matchAndRewrite(ToPointersOp op, ArrayRef<Value> operands, 289 ConversionPatternRewriter &rewriter) const override { 290 Type resType = op.getType(); 291 Type eltType = resType.cast<ShapedType>().getElementType(); 292 StringRef name; 293 if (eltType.isIndex()) 294 name = "sparsePointers"; 295 else if (eltType.isInteger(64)) 296 name = "sparsePointers64"; 297 else if (eltType.isInteger(32)) 298 name = "sparsePointers32"; 299 else if (eltType.isInteger(16)) 300 name = "sparsePointers16"; 301 else if (eltType.isInteger(8)) 302 name = "sparsePointers8"; 303 else 304 return failure(); 305 rewriter.replaceOpWithNewOp<CallOp>( 306 op, resType, getFunc(op, name, resType, operands), operands); 307 return success(); 308 } 309 }; 310 311 /// Sparse conversion rule for index accesses. 312 class SparseTensorToIndicesConverter : public OpConversionPattern<ToIndicesOp> { 313 public: 314 using OpConversionPattern::OpConversionPattern; 315 LogicalResult 316 matchAndRewrite(ToIndicesOp op, ArrayRef<Value> operands, 317 ConversionPatternRewriter &rewriter) const override { 318 Type resType = op.getType(); 319 Type eltType = resType.cast<ShapedType>().getElementType(); 320 StringRef name; 321 if (eltType.isIndex()) 322 name = "sparseIndices"; 323 else if (eltType.isInteger(64)) 324 name = "sparseIndices64"; 325 else if (eltType.isInteger(32)) 326 name = "sparseIndices32"; 327 else if (eltType.isInteger(16)) 328 name = "sparseIndices16"; 329 else if (eltType.isInteger(8)) 330 name = "sparseIndices8"; 331 else 332 return failure(); 333 rewriter.replaceOpWithNewOp<CallOp>( 334 op, resType, getFunc(op, name, resType, operands), operands); 335 return success(); 336 } 337 }; 338 339 /// Sparse conversion rule for value accesses. 340 class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> { 341 public: 342 using OpConversionPattern::OpConversionPattern; 343 LogicalResult 344 matchAndRewrite(ToValuesOp op, ArrayRef<Value> operands, 345 ConversionPatternRewriter &rewriter) const override { 346 Type resType = op.getType(); 347 Type eltType = resType.cast<ShapedType>().getElementType(); 348 StringRef name; 349 if (eltType.isF64()) 350 name = "sparseValuesF64"; 351 else if (eltType.isF32()) 352 name = "sparseValuesF32"; 353 else if (eltType.isInteger(64)) 354 name = "sparseValuesI64"; 355 else if (eltType.isInteger(32)) 356 name = "sparseValuesI32"; 357 else if (eltType.isInteger(16)) 358 name = "sparseValuesI16"; 359 else if (eltType.isInteger(8)) 360 name = "sparseValuesI8"; 361 else 362 return failure(); 363 rewriter.replaceOpWithNewOp<CallOp>( 364 op, resType, getFunc(op, name, resType, operands), operands); 365 return success(); 366 } 367 }; 368 369 /// Sparse conversion rule for tensor reconstruction. 370 class SparseTensorToTensorConverter : public OpConversionPattern<ToTensorOp> { 371 public: 372 using OpConversionPattern::OpConversionPattern; 373 LogicalResult 374 // Simply fold the operator into the pointer to the sparse storage scheme. 375 matchAndRewrite(ToTensorOp op, ArrayRef<Value> operands, 376 ConversionPatternRewriter &rewriter) const override { 377 // Check that all arguments of the tensor reconstruction operators are calls 378 // into the support library that query exactly the same opaque pointer. 379 Value ptr; 380 for (Value op : operands) { 381 if (auto call = op.getDefiningOp<CallOp>()) { 382 Value arg = call.getOperand(0); 383 if (!arg.getType().isa<LLVM::LLVMPointerType>()) 384 return failure(); 385 if (!ptr) 386 ptr = arg; 387 else if (arg != ptr) 388 return failure(); 389 } 390 } 391 // If a single opaque pointer is found, perform the folding. 392 if (!ptr) 393 return failure(); 394 rewriter.replaceOp(op, ptr); 395 return success(); 396 } 397 }; 398 399 } // namespace 400 401 //===----------------------------------------------------------------------===// 402 // Public method for populating conversion rules. 403 //===----------------------------------------------------------------------===// 404 405 /// Populates the given patterns list with conversion rules required for 406 /// the sparsification of linear algebra operations. 407 void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter, 408 RewritePatternSet &patterns) { 409 patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter, 410 SparseTensorNewConverter, SparseTensorConvertConverter, 411 SparseTensorToPointersConverter, SparseTensorToIndicesConverter, 412 SparseTensorToValuesConverter, SparseTensorToTensorConverter>( 413 typeConverter, patterns.getContext()); 414 } 415