1 //===- SparseTensorLowering.cpp - Sparse tensor primitives conversion -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Convert sparse tensor primitives to calls into a runtime support library. 10 // Note that this is a current implementation choice to keep the conversion 11 // simple. In principle, these primitives could also be converted to actual 12 // elaborate IR code that implements the primitives on the selected sparse 13 // tensor storage schemes. 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "mlir/Dialect/LLVMIR/LLVMTypes.h" 18 #include "mlir/Dialect/MemRef/IR/MemRef.h" 19 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 20 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" 21 #include "mlir/Dialect/StandardOps/IR/Ops.h" 22 #include "mlir/Dialect/Tensor/IR/Tensor.h" 23 #include "mlir/Transforms/DialectConversion.h" 24 25 using namespace mlir; 26 using namespace mlir::sparse_tensor; 27 28 namespace { 29 30 /// Returns internal type encoding for primary storage. Keep these 31 /// values consistent with the sparse runtime support library. 32 static unsigned getPrimaryTypeEncoding(Type tp) { 33 if (tp.isF64()) 34 return 1; 35 if (tp.isF32()) 36 return 2; 37 if (tp.isInteger(64)) 38 return 3; 39 if (tp.isInteger(32)) 40 return 4; 41 if (tp.isInteger(16)) 42 return 5; 43 if (tp.isInteger(8)) 44 return 6; 45 return 0; 46 } 47 48 /// Returns internal type encoding for overhead storage. Keep these 49 /// values consistent with the sparse runtime support library. 50 static unsigned getOverheadTypeEncoding(unsigned width) { 51 switch (width) { 52 default: 53 return 1; 54 case 32: 55 return 2; 56 case 16: 57 return 3; 58 case 8: 59 return 4; 60 } 61 } 62 63 /// Returns internal dimension level type encoding. Keep these 64 /// values consistent with the sparse runtime support library. 65 static unsigned 66 getDimLevelTypeEncoding(SparseTensorEncodingAttr::DimLevelType dlt) { 67 switch (dlt) { 68 case SparseTensorEncodingAttr::DimLevelType::Dense: 69 return 0; 70 case SparseTensorEncodingAttr::DimLevelType::Compressed: 71 return 1; 72 case SparseTensorEncodingAttr::DimLevelType::Singleton: 73 return 2; 74 } 75 llvm_unreachable("Unknown SparseTensorEncodingAttr::DimLevelType"); 76 } 77 78 /// Returns integers of given width and values as a constant tensor. 79 /// We cast the static shape into a dynamic shape to ensure that the 80 /// method signature remains uniform accross different tensor dimensions. 81 static Value getTensor(ConversionPatternRewriter &rewriter, unsigned width, 82 Location loc, ArrayRef<APInt> values) { 83 Type etp = rewriter.getIntegerType(width); 84 unsigned sz = values.size(); 85 RankedTensorType tt1 = RankedTensorType::get({sz}, etp); 86 RankedTensorType tt2 = RankedTensorType::get({ShapedType::kDynamicSize}, etp); 87 auto elts = 88 rewriter.create<ConstantOp>(loc, DenseElementsAttr::get(tt1, values)); 89 return rewriter.create<tensor::CastOp>(loc, tt2, elts); 90 } 91 92 /// Returns function reference (first hit also inserts into module). 93 static FlatSymbolRefAttr getFunc(Operation *op, StringRef name, Type result, 94 ValueRange operands) { 95 MLIRContext *context = op->getContext(); 96 auto module = op->getParentOfType<ModuleOp>(); 97 auto func = module.lookupSymbol<FuncOp>(name); 98 if (!func) { 99 OpBuilder moduleBuilder(module.getBodyRegion()); 100 moduleBuilder 101 .create<FuncOp>(op->getLoc(), name, 102 FunctionType::get(context, operands.getTypes(), result)) 103 .setPrivate(); 104 } 105 return SymbolRefAttr::get(context, name); 106 } 107 108 /// Sparse conversion rule for returns. 109 class SparseReturnConverter : public OpConversionPattern<ReturnOp> { 110 public: 111 using OpConversionPattern::OpConversionPattern; 112 LogicalResult 113 matchAndRewrite(ReturnOp op, ArrayRef<Value> operands, 114 ConversionPatternRewriter &rewriter) const override { 115 rewriter.replaceOpWithNewOp<ReturnOp>(op, operands); 116 return success(); 117 } 118 }; 119 120 /// Sparse conversion rule for dimension accesses. 121 class SparseTensorToDimSizeConverter 122 : public OpConversionPattern<tensor::DimOp> { 123 public: 124 using OpConversionPattern::OpConversionPattern; 125 LogicalResult 126 matchAndRewrite(tensor::DimOp op, ArrayRef<Value> operands, 127 ConversionPatternRewriter &rewriter) const override { 128 if (!operands[0].getType().isa<LLVM::LLVMPointerType>()) 129 return failure(); 130 Type resType = op.getType(); 131 StringRef name = "sparseDimSize"; 132 rewriter.replaceOpWithNewOp<CallOp>( 133 op, resType, getFunc(op, name, resType, operands), operands); 134 return success(); 135 } 136 }; 137 138 /// Sparse conversion rule for the new operator. 139 class SparseTensorNewConverter : public OpConversionPattern<NewOp> { 140 using OpConversionPattern::OpConversionPattern; 141 LogicalResult 142 matchAndRewrite(NewOp op, ArrayRef<Value> operands, 143 ConversionPatternRewriter &rewriter) const override { 144 Location loc = op.getLoc(); 145 Type resType = op.getType(); 146 Type eltType = resType.cast<ShapedType>().getElementType(); 147 MLIRContext *context = op->getContext(); 148 SmallVector<Value, 5> params; 149 // Sparse encoding. 150 auto enc = getSparseTensorEncoding(resType); 151 if (!enc) 152 return failure(); 153 // User pointer. 154 params.push_back(operands[0]); 155 // Sparsity annotations in tensor constant form. 156 SmallVector<APInt, 4> attrs; 157 unsigned sz = enc.getDimLevelType().size(); 158 for (unsigned i = 0; i < sz; i++) 159 attrs.push_back( 160 APInt(8, getDimLevelTypeEncoding(enc.getDimLevelType()[i]))); 161 params.push_back(getTensor(rewriter, 8, loc, attrs)); 162 // Dimension order permutation array. This is the "identity" 163 // permutation by default, or otherwise the "reverse" permutation 164 // of a given ordering, so that indices can be mapped quickly 165 // to the right position. 166 SmallVector<APInt, 4> perm(sz); 167 AffineMap p = enc.getDimOrdering(); 168 if (p) { 169 assert(p.isPermutation() && p.getNumResults() == sz); 170 for (unsigned i = 0; i < sz; i++) 171 perm[p.getDimPosition(i)] = APInt(64, i); 172 } else { 173 for (unsigned i = 0; i < sz; i++) 174 perm[i] = APInt(64, i); 175 } 176 params.push_back(getTensor(rewriter, 64, loc, perm)); 177 // Secondary and primary types encoding. 178 unsigned secPtr = getOverheadTypeEncoding(enc.getPointerBitWidth()); 179 unsigned secInd = getOverheadTypeEncoding(enc.getIndexBitWidth()); 180 unsigned primary = getPrimaryTypeEncoding(eltType); 181 if (!primary) 182 return failure(); 183 params.push_back( 184 rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(secPtr))); 185 params.push_back( 186 rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(secInd))); 187 params.push_back( 188 rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(primary))); 189 // Generate the call to create new tensor. 190 Type ptrType = LLVM::LLVMPointerType::get(IntegerType::get(context, 8)); 191 StringRef name = "newSparseTensor"; 192 rewriter.replaceOpWithNewOp<CallOp>( 193 op, ptrType, getFunc(op, name, ptrType, params), params); 194 return success(); 195 } 196 }; 197 198 /// Sparse conversion rule for the convert operator. 199 class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> { 200 using OpConversionPattern::OpConversionPattern; 201 LogicalResult 202 matchAndRewrite(ConvertOp op, ArrayRef<Value> operands, 203 ConversionPatternRewriter &rewriter) const override { 204 // TODO: implement conversions lowering 205 return failure(); 206 } 207 }; 208 209 /// Sparse conversion rule for pointer accesses. 210 class SparseTensorToPointersConverter 211 : public OpConversionPattern<ToPointersOp> { 212 public: 213 using OpConversionPattern::OpConversionPattern; 214 LogicalResult 215 matchAndRewrite(ToPointersOp op, ArrayRef<Value> operands, 216 ConversionPatternRewriter &rewriter) const override { 217 Type resType = op.getType(); 218 Type eltType = resType.cast<ShapedType>().getElementType(); 219 StringRef name; 220 if (eltType.isIndex()) 221 name = "sparsePointers"; 222 else if (eltType.isInteger(64)) 223 name = "sparsePointers64"; 224 else if (eltType.isInteger(32)) 225 name = "sparsePointers32"; 226 else if (eltType.isInteger(16)) 227 name = "sparsePointers16"; 228 else if (eltType.isInteger(8)) 229 name = "sparsePointers8"; 230 else 231 return failure(); 232 rewriter.replaceOpWithNewOp<CallOp>( 233 op, resType, getFunc(op, name, resType, operands), operands); 234 return success(); 235 } 236 }; 237 238 /// Sparse conversion rule for index accesses. 239 class SparseTensorToIndicesConverter : public OpConversionPattern<ToIndicesOp> { 240 public: 241 using OpConversionPattern::OpConversionPattern; 242 LogicalResult 243 matchAndRewrite(ToIndicesOp op, ArrayRef<Value> operands, 244 ConversionPatternRewriter &rewriter) const override { 245 Type resType = op.getType(); 246 Type eltType = resType.cast<ShapedType>().getElementType(); 247 StringRef name; 248 if (eltType.isIndex()) 249 name = "sparseIndices"; 250 else if (eltType.isInteger(64)) 251 name = "sparseIndices64"; 252 else if (eltType.isInteger(32)) 253 name = "sparseIndices32"; 254 else if (eltType.isInteger(16)) 255 name = "sparseIndices16"; 256 else if (eltType.isInteger(8)) 257 name = "sparseIndices8"; 258 else 259 return failure(); 260 rewriter.replaceOpWithNewOp<CallOp>( 261 op, resType, getFunc(op, name, resType, operands), operands); 262 return success(); 263 } 264 }; 265 266 /// Sparse conversion rule for value accesses. 267 class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> { 268 public: 269 using OpConversionPattern::OpConversionPattern; 270 LogicalResult 271 matchAndRewrite(ToValuesOp op, ArrayRef<Value> operands, 272 ConversionPatternRewriter &rewriter) const override { 273 Type resType = op.getType(); 274 Type eltType = resType.cast<ShapedType>().getElementType(); 275 StringRef name; 276 if (eltType.isF64()) 277 name = "sparseValuesF64"; 278 else if (eltType.isF32()) 279 name = "sparseValuesF32"; 280 else if (eltType.isInteger(64)) 281 name = "sparseValuesI64"; 282 else if (eltType.isInteger(32)) 283 name = "sparseValuesI32"; 284 else if (eltType.isInteger(16)) 285 name = "sparseValuesI16"; 286 else if (eltType.isInteger(8)) 287 name = "sparseValuesI8"; 288 else 289 return failure(); 290 rewriter.replaceOpWithNewOp<CallOp>( 291 op, resType, getFunc(op, name, resType, operands), operands); 292 return success(); 293 } 294 }; 295 296 /// Sparse conversion rule for tensor reconstruction. 297 class SparseTensorToTensorConverter : public OpConversionPattern<ToTensorOp> { 298 public: 299 using OpConversionPattern::OpConversionPattern; 300 LogicalResult 301 // Simply fold the operator into the pointer to the sparse storage scheme. 302 matchAndRewrite(ToTensorOp op, ArrayRef<Value> operands, 303 ConversionPatternRewriter &rewriter) const override { 304 // Check that all arguments of the tensor reconstruction operators are calls 305 // into the support library that query exactly the same opaque pointer. 306 Value ptr; 307 for (Value op : operands) { 308 if (auto call = op.getDefiningOp<CallOp>()) { 309 Value arg = call.getOperand(0); 310 if (!arg.getType().isa<LLVM::LLVMPointerType>()) 311 return failure(); 312 if (!ptr) 313 ptr = arg; 314 else if (arg != ptr) 315 return failure(); 316 } 317 } 318 // If a single opaque pointer is found, perform the folding. 319 if (!ptr) 320 return failure(); 321 rewriter.replaceOp(op, ptr); 322 return success(); 323 } 324 }; 325 326 } // namespace 327 328 /// Populates the given patterns list with conversion rules required for 329 /// the sparsification of linear algebra operations. 330 void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter, 331 RewritePatternSet &patterns) { 332 patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter, 333 SparseTensorNewConverter, SparseTensorConvertConverter, 334 SparseTensorToPointersConverter, SparseTensorToIndicesConverter, 335 SparseTensorToValuesConverter, SparseTensorToTensorConverter>( 336 typeConverter, patterns.getContext()); 337 } 338