1 //===- SparseTensorLowering.cpp - Sparse tensor primitives conversion -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Convert sparse tensor primitives to calls into a runtime support library. 10 // Note that this is a current implementation choice to keep the conversion 11 // simple. In principle, these primitives could also be converted to actual 12 // elaborate IR code that implements the primitives on the selected sparse 13 // tensor storage schemes. 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "mlir/Dialect/LLVMIR/LLVMTypes.h" 18 #include "mlir/Dialect/MemRef/IR/MemRef.h" 19 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 20 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" 21 #include "mlir/Dialect/StandardOps/IR/Ops.h" 22 #include "mlir/Dialect/Tensor/IR/Tensor.h" 23 #include "mlir/Transforms/DialectConversion.h" 24 25 using namespace mlir; 26 using namespace mlir::sparse_tensor; 27 28 namespace { 29 30 /// Returns internal type encoding for overhead storage. 31 static unsigned getOverheadTypeEncoding(unsigned width) { 32 switch (width) { 33 default: 34 return 1; 35 case 32: 36 return 2; 37 case 16: 38 return 3; 39 case 8: 40 return 4; 41 } 42 } 43 44 /// Returns internal dimension level type encoding. 45 static unsigned 46 getDimLevelTypeEncoding(SparseTensorEncodingAttr::DimLevelType dlt) { 47 switch (dlt) { 48 case SparseTensorEncodingAttr::DimLevelType::Dense: 49 return 0; 50 case SparseTensorEncodingAttr::DimLevelType::Compressed: 51 return 1; 52 case SparseTensorEncodingAttr::DimLevelType::Singleton: 53 return 2; 54 } 55 } 56 57 /// Returns integers of given width and values as a constant tensor. 58 /// We cast the static shape into a dynamic shape to ensure that the 59 /// method signature remains uniform accross different tensor dimensions. 60 static Value getTensor(ConversionPatternRewriter &rewriter, unsigned width, 61 Location loc, ArrayRef<APInt> values) { 62 Type etp = rewriter.getIntegerType(width); 63 unsigned sz = values.size(); 64 RankedTensorType tt1 = RankedTensorType::get({sz}, etp); 65 RankedTensorType tt2 = RankedTensorType::get({ShapedType::kDynamicSize}, etp); 66 auto elts = 67 rewriter.create<ConstantOp>(loc, DenseElementsAttr::get(tt1, values)); 68 return rewriter.create<tensor::CastOp>(loc, tt2, elts); 69 } 70 71 /// Returns function reference (first hit also inserts into module). 72 static FlatSymbolRefAttr getFunc(Operation *op, StringRef name, Type result, 73 ValueRange operands) { 74 MLIRContext *context = op->getContext(); 75 auto module = op->getParentOfType<ModuleOp>(); 76 auto func = module.lookupSymbol<FuncOp>(name); 77 if (!func) { 78 OpBuilder moduleBuilder(module.getBodyRegion()); 79 moduleBuilder 80 .create<FuncOp>(op->getLoc(), name, 81 FunctionType::get(context, operands.getTypes(), result)) 82 .setPrivate(); 83 } 84 return SymbolRefAttr::get(context, name); 85 } 86 87 /// Sparse conversion rule for returns. 88 class SparseReturnConverter : public OpConversionPattern<ReturnOp> { 89 public: 90 using OpConversionPattern::OpConversionPattern; 91 LogicalResult 92 matchAndRewrite(ReturnOp op, ArrayRef<Value> operands, 93 ConversionPatternRewriter &rewriter) const override { 94 rewriter.replaceOpWithNewOp<ReturnOp>(op, operands); 95 return success(); 96 } 97 }; 98 99 /// Sparse conversion rule for dimension accesses. 100 class SparseTensorToDimSizeConverter 101 : public OpConversionPattern<memref::DimOp> { 102 public: 103 using OpConversionPattern::OpConversionPattern; 104 LogicalResult 105 matchAndRewrite(memref::DimOp op, ArrayRef<Value> operands, 106 ConversionPatternRewriter &rewriter) const override { 107 if (!operands[0].getType().isa<LLVM::LLVMPointerType>()) 108 return failure(); 109 Type resType = op.getType(); 110 StringRef name = "sparseDimSize"; 111 rewriter.replaceOpWithNewOp<CallOp>( 112 op, resType, getFunc(op, name, resType, operands), operands); 113 return success(); 114 } 115 }; 116 117 /// Sparse conversion rule for the new operator. 118 class SparseTensorNewConverter : public OpConversionPattern<NewOp> { 119 using OpConversionPattern::OpConversionPattern; 120 LogicalResult 121 matchAndRewrite(NewOp op, ArrayRef<Value> operands, 122 ConversionPatternRewriter &rewriter) const override { 123 Location loc = op.getLoc(); 124 Type resType = op.getType(); 125 Type eltType = resType.cast<ShapedType>().getElementType(); 126 MLIRContext *context = op->getContext(); 127 SmallVector<Value, 5> params; 128 // Sparse encoding. 129 auto enc = getSparseTensorEncoding(resType); 130 if (!enc) 131 return failure(); 132 // User pointer. 133 params.push_back(operands[0]); 134 // Sparsity annotations in tensor constant form. 135 SmallVector<APInt, 4> attrs; 136 unsigned sz = enc.getDimLevelType().size(); 137 for (unsigned i = 0; i < sz; i++) 138 attrs.push_back( 139 APInt(8, getDimLevelTypeEncoding(enc.getDimLevelType()[i]))); 140 params.push_back(getTensor(rewriter, 8, loc, attrs)); 141 // Dimension order permutation array. This is the "identity" 142 // permutation by default, or otherwise the "reverse" permutation 143 // of a given ordering, so that indices can be mapped quickly 144 // to the right position. 145 SmallVector<APInt, 4> perm(sz); 146 AffineMap p = enc.getDimOrdering(); 147 if (p) { 148 assert(p.isPermutation() && p.getNumResults() == sz); 149 for (unsigned i = 0; i < sz; i++) 150 perm[p.getDimPosition(i)] = APInt(64, i); 151 } else { 152 for (unsigned i = 0; i < sz; i++) 153 perm[i] = APInt(64, i); 154 } 155 params.push_back(getTensor(rewriter, 64, loc, perm)); 156 // Secondary and primary types encoding. 157 unsigned secPtr = getOverheadTypeEncoding(enc.getPointerBitWidth()); 158 unsigned secInd = getOverheadTypeEncoding(enc.getIndexBitWidth()); 159 unsigned primary; 160 if (eltType.isF64()) 161 primary = 1; 162 else if (eltType.isF32()) 163 primary = 2; 164 else if (eltType.isInteger(32)) 165 primary = 3; 166 else if (eltType.isInteger(16)) 167 primary = 4; 168 else if (eltType.isInteger(8)) 169 primary = 5; 170 else 171 return failure(); 172 params.push_back( 173 rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(secPtr))); 174 params.push_back( 175 rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(secInd))); 176 params.push_back( 177 rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(primary))); 178 // Generate the call to create new tensor. 179 Type ptrType = LLVM::LLVMPointerType::get(IntegerType::get(context, 8)); 180 StringRef name = "newSparseTensor"; 181 rewriter.replaceOpWithNewOp<CallOp>( 182 op, ptrType, getFunc(op, name, ptrType, params), params); 183 return success(); 184 } 185 }; 186 187 /// Sparse conversion rule for pointer accesses. 188 class SparseTensorToPointersConverter 189 : public OpConversionPattern<ToPointersOp> { 190 public: 191 using OpConversionPattern::OpConversionPattern; 192 LogicalResult 193 matchAndRewrite(ToPointersOp op, ArrayRef<Value> operands, 194 ConversionPatternRewriter &rewriter) const override { 195 Type resType = op.getType(); 196 Type eltType = resType.cast<ShapedType>().getElementType(); 197 StringRef name; 198 if (eltType.isIndex()) 199 name = "sparsePointers"; 200 else if (eltType.isInteger(64)) 201 name = "sparsePointers64"; 202 else if (eltType.isInteger(32)) 203 name = "sparsePointers32"; 204 else if (eltType.isInteger(16)) 205 name = "sparsePointers16"; 206 else if (eltType.isInteger(8)) 207 name = "sparsePointers8"; 208 else 209 return failure(); 210 rewriter.replaceOpWithNewOp<CallOp>( 211 op, resType, getFunc(op, name, resType, operands), operands); 212 return success(); 213 } 214 }; 215 216 /// Sparse conversion rule for index accesses. 217 class SparseTensorToIndicesConverter : public OpConversionPattern<ToIndicesOp> { 218 public: 219 using OpConversionPattern::OpConversionPattern; 220 LogicalResult 221 matchAndRewrite(ToIndicesOp op, ArrayRef<Value> operands, 222 ConversionPatternRewriter &rewriter) const override { 223 Type resType = op.getType(); 224 Type eltType = resType.cast<ShapedType>().getElementType(); 225 StringRef name; 226 if (eltType.isIndex()) 227 name = "sparseIndices"; 228 else if (eltType.isInteger(64)) 229 name = "sparseIndices64"; 230 else if (eltType.isInteger(32)) 231 name = "sparseIndices32"; 232 else if (eltType.isInteger(16)) 233 name = "sparseIndices16"; 234 else if (eltType.isInteger(8)) 235 name = "sparseIndices8"; 236 else 237 return failure(); 238 rewriter.replaceOpWithNewOp<CallOp>( 239 op, resType, getFunc(op, name, resType, operands), operands); 240 return success(); 241 } 242 }; 243 244 /// Sparse conversion rule for value accesses. 245 class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> { 246 public: 247 using OpConversionPattern::OpConversionPattern; 248 LogicalResult 249 matchAndRewrite(ToValuesOp op, ArrayRef<Value> operands, 250 ConversionPatternRewriter &rewriter) const override { 251 Type resType = op.getType(); 252 Type eltType = resType.cast<ShapedType>().getElementType(); 253 StringRef name; 254 if (eltType.isF64()) 255 name = "sparseValuesF64"; 256 else if (eltType.isF32()) 257 name = "sparseValuesF32"; 258 else if (eltType.isInteger(32)) 259 name = "sparseValuesI32"; 260 else if (eltType.isInteger(16)) 261 name = "sparseValuesI16"; 262 else if (eltType.isInteger(8)) 263 name = "sparseValuesI8"; 264 else 265 return failure(); 266 rewriter.replaceOpWithNewOp<CallOp>( 267 op, resType, getFunc(op, name, resType, operands), operands); 268 return success(); 269 } 270 }; 271 272 } // namespace 273 274 /// Populates the given patterns list with conversion rules required for 275 /// the sparsification of linear algebra operations. 276 void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter, 277 RewritePatternSet &patterns) { 278 patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter, 279 SparseTensorNewConverter, SparseTensorToPointersConverter, 280 SparseTensorToIndicesConverter, SparseTensorToValuesConverter>( 281 typeConverter, patterns.getContext()); 282 } 283