1 //===- SparseTensorLowering.cpp - Sparse tensor primitives conversion -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Convert sparse tensor primitives to calls into a runtime support library. 10 // Note that this is a current implementation choice to keep the conversion 11 // simple. In principle, these primitives could also be converted to actual 12 // elaborate IR code that implements the primitives on the selected sparse 13 // tensor storage schemes. 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "mlir/Dialect/LLVMIR/LLVMTypes.h" 18 #include "mlir/Dialect/MemRef/IR/MemRef.h" 19 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 20 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" 21 #include "mlir/Dialect/StandardOps/IR/Ops.h" 22 #include "mlir/Dialect/Tensor/IR/Tensor.h" 23 #include "mlir/Transforms/DialectConversion.h" 24 25 using namespace mlir; 26 using namespace mlir::sparse_tensor; 27 28 namespace { 29 30 /// Returns internal type encoding for overhead storage. 31 static unsigned getOverheadTypeEncoding(unsigned width) { 32 switch (width) { 33 default: 34 return 1; 35 case 32: 36 return 2; 37 case 16: 38 return 3; 39 case 8: 40 return 4; 41 } 42 } 43 44 /// Returns internal dimension level type encoding. 45 static unsigned 46 getDimLevelTypeEncoding(SparseTensorEncodingAttr::DimLevelType dlt) { 47 switch (dlt) { 48 case SparseTensorEncodingAttr::DimLevelType::Dense: 49 return 0; 50 case SparseTensorEncodingAttr::DimLevelType::Compressed: 51 return 1; 52 case SparseTensorEncodingAttr::DimLevelType::Singleton: 53 return 2; 54 } 55 } 56 57 /// Returns function reference (first hit also inserts into module). 58 static FlatSymbolRefAttr getFunc(Operation *op, StringRef name, Type result, 59 ValueRange operands) { 60 MLIRContext *context = op->getContext(); 61 auto module = op->getParentOfType<ModuleOp>(); 62 auto func = module.lookupSymbol<FuncOp>(name); 63 if (!func) { 64 OpBuilder moduleBuilder(module.getBodyRegion()); 65 moduleBuilder 66 .create<FuncOp>(op->getLoc(), name, 67 FunctionType::get(context, operands.getTypes(), result)) 68 .setPrivate(); 69 } 70 return SymbolRefAttr::get(context, name); 71 } 72 73 /// Sparse conversion rule for returns. 74 class SparseReturnConverter : public OpConversionPattern<ReturnOp> { 75 public: 76 using OpConversionPattern::OpConversionPattern; 77 LogicalResult 78 matchAndRewrite(ReturnOp op, ArrayRef<Value> operands, 79 ConversionPatternRewriter &rewriter) const override { 80 rewriter.replaceOpWithNewOp<ReturnOp>(op, operands); 81 return success(); 82 } 83 }; 84 85 /// Sparse conversion rule for dimension accesses. 86 class SparseTensorToDimSizeConverter 87 : public OpConversionPattern<memref::DimOp> { 88 public: 89 using OpConversionPattern::OpConversionPattern; 90 LogicalResult 91 matchAndRewrite(memref::DimOp op, ArrayRef<Value> operands, 92 ConversionPatternRewriter &rewriter) const override { 93 if (!operands[0].getType().isa<LLVM::LLVMPointerType>()) 94 return failure(); 95 Type resType = op.getType(); 96 StringRef name = "sparseDimSize"; 97 rewriter.replaceOpWithNewOp<CallOp>( 98 op, resType, getFunc(op, name, resType, operands), operands); 99 return success(); 100 } 101 }; 102 103 /// Sparse conversion rule for the new operator. 104 class SparseTensorNewConverter : public OpConversionPattern<NewOp> { 105 using OpConversionPattern::OpConversionPattern; 106 LogicalResult 107 matchAndRewrite(NewOp op, ArrayRef<Value> operands, 108 ConversionPatternRewriter &rewriter) const override { 109 Location loc = op.getLoc(); 110 Type resType = op.getType(); 111 Type eltType = resType.cast<ShapedType>().getElementType(); 112 MLIRContext *context = op->getContext(); 113 SmallVector<Value, 5> params; 114 // Sparse encoding. 115 auto enc = getSparseTensorEncoding(resType); 116 if (!enc) 117 return failure(); 118 // User pointer. 119 params.push_back(operands[0]); 120 // Sparsity annotations in tensor constant form. Note that we cast 121 // the static shape into a dynamic shape to ensure that the method 122 // signature remains uniform accross different tensor dimensions. 123 SmallVector<APInt, 4> attrs; 124 unsigned sz = enc.getDimLevelType().size(); 125 for (unsigned i = 0; i < sz; i++) 126 attrs.push_back( 127 APInt(8, getDimLevelTypeEncoding(enc.getDimLevelType()[i]))); 128 Type etp = rewriter.getIntegerType(8); 129 RankedTensorType tt1 = RankedTensorType::get({sz}, etp); 130 RankedTensorType tt2 = 131 RankedTensorType::get({ShapedType::kDynamicSize}, etp); 132 auto elts = 133 rewriter.create<ConstantOp>(loc, DenseElementsAttr::get(tt1, attrs)); 134 params.push_back(rewriter.create<tensor::CastOp>(loc, tt2, elts)); 135 // Seconary and primary types encoding. 136 unsigned secPtr = getOverheadTypeEncoding(enc.getPointerBitWidth()); 137 unsigned secInd = getOverheadTypeEncoding(enc.getIndexBitWidth()); 138 unsigned primary; 139 if (eltType.isF64()) 140 primary = 1; 141 else if (eltType.isF32()) 142 primary = 2; 143 else if (eltType.isInteger(32)) 144 primary = 3; 145 else if (eltType.isInteger(16)) 146 primary = 4; 147 else if (eltType.isInteger(8)) 148 primary = 5; 149 else 150 return failure(); 151 params.push_back( 152 rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(secPtr))); 153 params.push_back( 154 rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(secInd))); 155 params.push_back( 156 rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(primary))); 157 // Generate the call to create new tensor. 158 Type ptrType = LLVM::LLVMPointerType::get(IntegerType::get(context, 8)); 159 StringRef name = "newSparseTensor"; 160 rewriter.replaceOpWithNewOp<CallOp>( 161 op, ptrType, getFunc(op, name, ptrType, params), params); 162 return success(); 163 } 164 }; 165 166 /// Sparse conversion rule for pointer accesses. 167 class SparseTensorToPointersConverter 168 : public OpConversionPattern<ToPointersOp> { 169 public: 170 using OpConversionPattern::OpConversionPattern; 171 LogicalResult 172 matchAndRewrite(ToPointersOp op, ArrayRef<Value> operands, 173 ConversionPatternRewriter &rewriter) const override { 174 Type resType = op.getType(); 175 Type eltType = resType.cast<ShapedType>().getElementType(); 176 StringRef name; 177 if (eltType.isIndex()) 178 name = "sparsePointers"; 179 else if (eltType.isInteger(64)) 180 name = "sparsePointers64"; 181 else if (eltType.isInteger(32)) 182 name = "sparsePointers32"; 183 else if (eltType.isInteger(16)) 184 name = "sparsePointers16"; 185 else if (eltType.isInteger(8)) 186 name = "sparsePointers8"; 187 else 188 return failure(); 189 rewriter.replaceOpWithNewOp<CallOp>( 190 op, resType, getFunc(op, name, resType, operands), operands); 191 return success(); 192 } 193 }; 194 195 /// Sparse conversion rule for index accesses. 196 class SparseTensorToIndicesConverter : public OpConversionPattern<ToIndicesOp> { 197 public: 198 using OpConversionPattern::OpConversionPattern; 199 LogicalResult 200 matchAndRewrite(ToIndicesOp op, ArrayRef<Value> operands, 201 ConversionPatternRewriter &rewriter) const override { 202 Type resType = op.getType(); 203 Type eltType = resType.cast<ShapedType>().getElementType(); 204 StringRef name; 205 if (eltType.isIndex()) 206 name = "sparseIndices"; 207 else if (eltType.isInteger(64)) 208 name = "sparseIndices64"; 209 else if (eltType.isInteger(32)) 210 name = "sparseIndices32"; 211 else if (eltType.isInteger(16)) 212 name = "sparseIndices16"; 213 else if (eltType.isInteger(8)) 214 name = "sparseIndices8"; 215 else 216 return failure(); 217 rewriter.replaceOpWithNewOp<CallOp>( 218 op, resType, getFunc(op, name, resType, operands), operands); 219 return success(); 220 } 221 }; 222 223 /// Sparse conversion rule for value accesses. 224 class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> { 225 public: 226 using OpConversionPattern::OpConversionPattern; 227 LogicalResult 228 matchAndRewrite(ToValuesOp op, ArrayRef<Value> operands, 229 ConversionPatternRewriter &rewriter) const override { 230 Type resType = op.getType(); 231 Type eltType = resType.cast<ShapedType>().getElementType(); 232 StringRef name; 233 if (eltType.isF64()) 234 name = "sparseValuesF64"; 235 else if (eltType.isF32()) 236 name = "sparseValuesF32"; 237 else if (eltType.isInteger(32)) 238 name = "sparseValuesI32"; 239 else if (eltType.isInteger(16)) 240 name = "sparseValuesI16"; 241 else if (eltType.isInteger(8)) 242 name = "sparseValuesI8"; 243 else 244 return failure(); 245 rewriter.replaceOpWithNewOp<CallOp>( 246 op, resType, getFunc(op, name, resType, operands), operands); 247 return success(); 248 } 249 }; 250 251 } // namespace 252 253 /// Populates the given patterns list with conversion rules required for 254 /// the sparsification of linear algebra operations. 255 void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter, 256 RewritePatternSet &patterns) { 257 patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter, 258 SparseTensorNewConverter, SparseTensorToPointersConverter, 259 SparseTensorToIndicesConverter, SparseTensorToValuesConverter>( 260 typeConverter, patterns.getContext()); 261 } 262