1 //===- SparseTensorLowering.cpp - Sparse tensor primitives conversion -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Convert sparse tensor primitives to calls into a runtime support library. 10 // Note that this is a current implementation choice to keep the conversion 11 // simple. In principle, these primitives could also be converted to actual 12 // elaborate IR code that implements the primitives on the selected sparse 13 // tensor storage schemes. 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "mlir/Dialect/LLVMIR/LLVMTypes.h" 18 #include "mlir/Dialect/MemRef/IR/MemRef.h" 19 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 20 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" 21 #include "mlir/Dialect/StandardOps/IR/Ops.h" 22 #include "mlir/Dialect/Tensor/IR/Tensor.h" 23 #include "mlir/Transforms/DialectConversion.h" 24 25 using namespace mlir; 26 using namespace mlir::sparse_tensor; 27 28 namespace { 29 30 /// Returns internal type encoding for overhead storage. 31 static unsigned getOverheadTypeEncoding(unsigned width) { 32 switch (width) { 33 default: 34 return 1; 35 case 32: 36 return 2; 37 case 16: 38 return 3; 39 case 8: 40 return 4; 41 } 42 } 43 44 /// Returns internal dimension level type encoding. 45 static unsigned 46 getDimLevelTypeEncoding(SparseTensorEncodingAttr::DimLevelType dlt) { 47 switch (dlt) { 48 case SparseTensorEncodingAttr::DimLevelType::Dense: 49 return 0; 50 case SparseTensorEncodingAttr::DimLevelType::Compressed: 51 return 1; 52 case SparseTensorEncodingAttr::DimLevelType::Singleton: 53 return 2; 54 } 55 llvm_unreachable("Unknown SparseTensorEncodingAttr::DimLevelType"); 56 } 57 58 /// Returns integers of given width and values as a constant tensor. 59 /// We cast the static shape into a dynamic shape to ensure that the 60 /// method signature remains uniform accross different tensor dimensions. 61 static Value getTensor(ConversionPatternRewriter &rewriter, unsigned width, 62 Location loc, ArrayRef<APInt> values) { 63 Type etp = rewriter.getIntegerType(width); 64 unsigned sz = values.size(); 65 RankedTensorType tt1 = RankedTensorType::get({sz}, etp); 66 RankedTensorType tt2 = RankedTensorType::get({ShapedType::kDynamicSize}, etp); 67 auto elts = 68 rewriter.create<ConstantOp>(loc, DenseElementsAttr::get(tt1, values)); 69 return rewriter.create<tensor::CastOp>(loc, tt2, elts); 70 } 71 72 /// Returns function reference (first hit also inserts into module). 73 static FlatSymbolRefAttr getFunc(Operation *op, StringRef name, Type result, 74 ValueRange operands) { 75 MLIRContext *context = op->getContext(); 76 auto module = op->getParentOfType<ModuleOp>(); 77 auto func = module.lookupSymbol<FuncOp>(name); 78 if (!func) { 79 OpBuilder moduleBuilder(module.getBodyRegion()); 80 moduleBuilder 81 .create<FuncOp>(op->getLoc(), name, 82 FunctionType::get(context, operands.getTypes(), result)) 83 .setPrivate(); 84 } 85 return SymbolRefAttr::get(context, name); 86 } 87 88 /// Sparse conversion rule for returns. 89 class SparseReturnConverter : public OpConversionPattern<ReturnOp> { 90 public: 91 using OpConversionPattern::OpConversionPattern; 92 LogicalResult 93 matchAndRewrite(ReturnOp op, ArrayRef<Value> operands, 94 ConversionPatternRewriter &rewriter) const override { 95 rewriter.replaceOpWithNewOp<ReturnOp>(op, operands); 96 return success(); 97 } 98 }; 99 100 /// Sparse conversion rule for dimension accesses. 101 class SparseTensorToDimSizeConverter 102 : public OpConversionPattern<memref::DimOp> { 103 public: 104 using OpConversionPattern::OpConversionPattern; 105 LogicalResult 106 matchAndRewrite(memref::DimOp op, ArrayRef<Value> operands, 107 ConversionPatternRewriter &rewriter) const override { 108 if (!operands[0].getType().isa<LLVM::LLVMPointerType>()) 109 return failure(); 110 Type resType = op.getType(); 111 StringRef name = "sparseDimSize"; 112 rewriter.replaceOpWithNewOp<CallOp>( 113 op, resType, getFunc(op, name, resType, operands), operands); 114 return success(); 115 } 116 }; 117 118 /// Sparse conversion rule for the new operator. 119 class SparseTensorNewConverter : public OpConversionPattern<NewOp> { 120 using OpConversionPattern::OpConversionPattern; 121 LogicalResult 122 matchAndRewrite(NewOp op, ArrayRef<Value> operands, 123 ConversionPatternRewriter &rewriter) const override { 124 Location loc = op.getLoc(); 125 Type resType = op.getType(); 126 Type eltType = resType.cast<ShapedType>().getElementType(); 127 MLIRContext *context = op->getContext(); 128 SmallVector<Value, 5> params; 129 // Sparse encoding. 130 auto enc = getSparseTensorEncoding(resType); 131 if (!enc) 132 return failure(); 133 // User pointer. 134 params.push_back(operands[0]); 135 // Sparsity annotations in tensor constant form. 136 SmallVector<APInt, 4> attrs; 137 unsigned sz = enc.getDimLevelType().size(); 138 for (unsigned i = 0; i < sz; i++) 139 attrs.push_back( 140 APInt(8, getDimLevelTypeEncoding(enc.getDimLevelType()[i]))); 141 params.push_back(getTensor(rewriter, 8, loc, attrs)); 142 // Dimension order permutation array. This is the "identity" 143 // permutation by default, or otherwise the "reverse" permutation 144 // of a given ordering, so that indices can be mapped quickly 145 // to the right position. 146 SmallVector<APInt, 4> perm(sz); 147 AffineMap p = enc.getDimOrdering(); 148 if (p) { 149 assert(p.isPermutation() && p.getNumResults() == sz); 150 for (unsigned i = 0; i < sz; i++) 151 perm[p.getDimPosition(i)] = APInt(64, i); 152 } else { 153 for (unsigned i = 0; i < sz; i++) 154 perm[i] = APInt(64, i); 155 } 156 params.push_back(getTensor(rewriter, 64, loc, perm)); 157 // Secondary and primary types encoding. 158 unsigned secPtr = getOverheadTypeEncoding(enc.getPointerBitWidth()); 159 unsigned secInd = getOverheadTypeEncoding(enc.getIndexBitWidth()); 160 unsigned primary; 161 if (eltType.isF64()) 162 primary = 1; 163 else if (eltType.isF32()) 164 primary = 2; 165 else if (eltType.isInteger(32)) 166 primary = 3; 167 else if (eltType.isInteger(16)) 168 primary = 4; 169 else if (eltType.isInteger(8)) 170 primary = 5; 171 else 172 return failure(); 173 params.push_back( 174 rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(secPtr))); 175 params.push_back( 176 rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(secInd))); 177 params.push_back( 178 rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(primary))); 179 // Generate the call to create new tensor. 180 Type ptrType = LLVM::LLVMPointerType::get(IntegerType::get(context, 8)); 181 StringRef name = "newSparseTensor"; 182 rewriter.replaceOpWithNewOp<CallOp>( 183 op, ptrType, getFunc(op, name, ptrType, params), params); 184 return success(); 185 } 186 }; 187 188 /// Sparse conversion rule for pointer accesses. 189 class SparseTensorToPointersConverter 190 : public OpConversionPattern<ToPointersOp> { 191 public: 192 using OpConversionPattern::OpConversionPattern; 193 LogicalResult 194 matchAndRewrite(ToPointersOp op, ArrayRef<Value> operands, 195 ConversionPatternRewriter &rewriter) const override { 196 Type resType = op.getType(); 197 Type eltType = resType.cast<ShapedType>().getElementType(); 198 StringRef name; 199 if (eltType.isIndex()) 200 name = "sparsePointers"; 201 else if (eltType.isInteger(64)) 202 name = "sparsePointers64"; 203 else if (eltType.isInteger(32)) 204 name = "sparsePointers32"; 205 else if (eltType.isInteger(16)) 206 name = "sparsePointers16"; 207 else if (eltType.isInteger(8)) 208 name = "sparsePointers8"; 209 else 210 return failure(); 211 rewriter.replaceOpWithNewOp<CallOp>( 212 op, resType, getFunc(op, name, resType, operands), operands); 213 return success(); 214 } 215 }; 216 217 /// Sparse conversion rule for index accesses. 218 class SparseTensorToIndicesConverter : public OpConversionPattern<ToIndicesOp> { 219 public: 220 using OpConversionPattern::OpConversionPattern; 221 LogicalResult 222 matchAndRewrite(ToIndicesOp op, ArrayRef<Value> operands, 223 ConversionPatternRewriter &rewriter) const override { 224 Type resType = op.getType(); 225 Type eltType = resType.cast<ShapedType>().getElementType(); 226 StringRef name; 227 if (eltType.isIndex()) 228 name = "sparseIndices"; 229 else if (eltType.isInteger(64)) 230 name = "sparseIndices64"; 231 else if (eltType.isInteger(32)) 232 name = "sparseIndices32"; 233 else if (eltType.isInteger(16)) 234 name = "sparseIndices16"; 235 else if (eltType.isInteger(8)) 236 name = "sparseIndices8"; 237 else 238 return failure(); 239 rewriter.replaceOpWithNewOp<CallOp>( 240 op, resType, getFunc(op, name, resType, operands), operands); 241 return success(); 242 } 243 }; 244 245 /// Sparse conversion rule for value accesses. 246 class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> { 247 public: 248 using OpConversionPattern::OpConversionPattern; 249 LogicalResult 250 matchAndRewrite(ToValuesOp op, ArrayRef<Value> operands, 251 ConversionPatternRewriter &rewriter) const override { 252 Type resType = op.getType(); 253 Type eltType = resType.cast<ShapedType>().getElementType(); 254 StringRef name; 255 if (eltType.isF64()) 256 name = "sparseValuesF64"; 257 else if (eltType.isF32()) 258 name = "sparseValuesF32"; 259 else if (eltType.isInteger(32)) 260 name = "sparseValuesI32"; 261 else if (eltType.isInteger(16)) 262 name = "sparseValuesI16"; 263 else if (eltType.isInteger(8)) 264 name = "sparseValuesI8"; 265 else 266 return failure(); 267 rewriter.replaceOpWithNewOp<CallOp>( 268 op, resType, getFunc(op, name, resType, operands), operands); 269 return success(); 270 } 271 }; 272 273 /// Sparse conversion rule for tensor reconstruction. 274 class SparseTensorToTensorConverter : public OpConversionPattern<ToTensorOp> { 275 public: 276 using OpConversionPattern::OpConversionPattern; 277 LogicalResult 278 // Simply fold the operator into the pointer to the sparse storage scheme. 279 // TODO: generalize this beyond all-dense linearized "sparse" tensors 280 matchAndRewrite(ToTensorOp op, ArrayRef<Value> operands, 281 ConversionPatternRewriter &rewriter) const override { 282 if (auto call = operands[0].getDefiningOp<CallOp>()) { 283 Value arg = call.getOperand(0); 284 if (arg.getType().isa<LLVM::LLVMPointerType>()) { 285 rewriter.replaceOp(op, arg); 286 return success(); 287 } 288 } 289 return failure(); 290 } 291 }; 292 293 } // namespace 294 295 /// Populates the given patterns list with conversion rules required for 296 /// the sparsification of linear algebra operations. 297 void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter, 298 RewritePatternSet &patterns) { 299 patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter, 300 SparseTensorNewConverter, SparseTensorToPointersConverter, 301 SparseTensorToIndicesConverter, SparseTensorToValuesConverter, 302 SparseTensorToTensorConverter>(typeConverter, 303 patterns.getContext()); 304 } 305