1 //===- SparseTensorLowering.cpp - Sparse tensor primitives conversion -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Convert sparse tensor primitives to calls into a runtime support library. 10 // Note that this is a current implementation choice to keep the conversion 11 // simple. In principle, these primitives could also be converted to actual 12 // elaborate IR code that implements the primitives on the selected sparse 13 // tensor storage schemes. 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "mlir/Dialect/LLVMIR/LLVMTypes.h" 18 #include "mlir/Dialect/MemRef/IR/MemRef.h" 19 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 20 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" 21 #include "mlir/Dialect/StandardOps/IR/Ops.h" 22 #include "mlir/Dialect/Tensor/IR/Tensor.h" 23 #include "mlir/Transforms/DialectConversion.h" 24 25 using namespace mlir; 26 using namespace mlir::sparse_tensor; 27 28 namespace { 29 30 /// Returns internal type encoding for overhead storage. 31 static unsigned getOverheadTypeEncoding(unsigned width) { 32 switch (width) { 33 default: 34 return 1; 35 case 32: 36 return 2; 37 case 16: 38 return 3; 39 case 8: 40 return 4; 41 } 42 } 43 44 /// Returns function reference (first hit also inserts into module). 45 static FlatSymbolRefAttr getFunc(Operation *op, StringRef name, Type result, 46 ValueRange operands) { 47 MLIRContext *context = op->getContext(); 48 auto module = op->getParentOfType<ModuleOp>(); 49 auto func = module.lookupSymbol<FuncOp>(name); 50 if (!func) { 51 OpBuilder moduleBuilder(module.getBodyRegion()); 52 moduleBuilder 53 .create<FuncOp>(op->getLoc(), name, 54 FunctionType::get(context, operands.getTypes(), result)) 55 .setPrivate(); 56 } 57 return SymbolRefAttr::get(context, name); 58 } 59 60 /// Sparse conversion rule for returns. 61 class SparseReturnConverter : public OpConversionPattern<ReturnOp> { 62 public: 63 using OpConversionPattern::OpConversionPattern; 64 LogicalResult 65 matchAndRewrite(ReturnOp op, ArrayRef<Value> operands, 66 ConversionPatternRewriter &rewriter) const override { 67 rewriter.replaceOpWithNewOp<ReturnOp>(op, operands); 68 return success(); 69 } 70 }; 71 72 /// Sparse conversion rule for dimension accesses. 73 class SparseTensorToDimSizeConverter 74 : public OpConversionPattern<memref::DimOp> { 75 public: 76 using OpConversionPattern::OpConversionPattern; 77 LogicalResult 78 matchAndRewrite(memref::DimOp op, ArrayRef<Value> operands, 79 ConversionPatternRewriter &rewriter) const override { 80 if (!operands[0].getType().isa<LLVM::LLVMPointerType>()) 81 return failure(); 82 Type resType = op.getType(); 83 StringRef name = "sparseDimSize"; 84 rewriter.replaceOpWithNewOp<CallOp>( 85 op, resType, getFunc(op, name, resType, operands), operands); 86 return success(); 87 } 88 }; 89 90 /// Sparse conversion rule for the new operator. 91 class SparseTensorNewConverter : public OpConversionPattern<NewOp> { 92 using OpConversionPattern::OpConversionPattern; 93 LogicalResult 94 matchAndRewrite(NewOp op, ArrayRef<Value> operands, 95 ConversionPatternRewriter &rewriter) const override { 96 Location loc = op.getLoc(); 97 Type resType = op.getType(); 98 Type eltType = resType.cast<ShapedType>().getElementType(); 99 MLIRContext *context = op->getContext(); 100 SmallVector<Value, 5> params; 101 // Sparse encoding. 102 auto enc = getSparseTensorEncoding(resType); 103 if (!enc) 104 return failure(); 105 // User pointer. 106 params.push_back(operands[0]); 107 // Sparsity annotations in tensor constant form. Note that we cast 108 // the static shape into a dynamic shape to ensure that the method 109 // signature remains uniform accross different tensor dimensions. 110 SmallVector<bool, 4> attrs; 111 unsigned sz = enc.getDimLevelType().size(); 112 for (unsigned i = 0; i < sz; i++) 113 attrs.push_back(enc.getDimLevelType()[i] == 114 SparseTensorEncodingAttr::DimLevelType::Compressed); 115 Type etp = rewriter.getIntegerType(1); 116 RankedTensorType tt1 = RankedTensorType::get({sz}, etp); 117 RankedTensorType tt2 = 118 RankedTensorType::get({ShapedType::kDynamicSize}, etp); 119 auto elts = 120 rewriter.create<ConstantOp>(loc, DenseElementsAttr::get(tt1, attrs)); 121 params.push_back(rewriter.create<tensor::CastOp>(loc, tt2, elts)); 122 // Seconary and primary types encoding. 123 unsigned secPtr = getOverheadTypeEncoding(enc.getPointerBitWidth()); 124 unsigned secInd = getOverheadTypeEncoding(enc.getIndexBitWidth()); 125 unsigned primary; 126 if (eltType.isF64()) 127 primary = 1; 128 else if (eltType.isF32()) 129 primary = 2; 130 else if (eltType.isInteger(32)) 131 primary = 3; 132 else if (eltType.isInteger(16)) 133 primary = 4; 134 else if (eltType.isInteger(8)) 135 primary = 5; 136 else 137 return failure(); 138 params.push_back( 139 rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(secPtr))); 140 params.push_back( 141 rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(secInd))); 142 params.push_back( 143 rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(primary))); 144 // Generate the call to create new tensor. 145 Type ptrType = LLVM::LLVMPointerType::get(IntegerType::get(context, 8)); 146 StringRef name = "newSparseTensor"; 147 rewriter.replaceOpWithNewOp<CallOp>( 148 op, ptrType, getFunc(op, name, ptrType, params), params); 149 return success(); 150 } 151 }; 152 153 /// Sparse conversion rule for pointer accesses. 154 class SparseTensorToPointersConverter 155 : public OpConversionPattern<ToPointersOp> { 156 public: 157 using OpConversionPattern::OpConversionPattern; 158 LogicalResult 159 matchAndRewrite(ToPointersOp op, ArrayRef<Value> operands, 160 ConversionPatternRewriter &rewriter) const override { 161 Type resType = op.getType(); 162 Type eltType = resType.cast<ShapedType>().getElementType(); 163 StringRef name; 164 if (eltType.isIndex()) 165 name = "sparsePointers"; 166 else if (eltType.isInteger(64)) 167 name = "sparsePointers64"; 168 else if (eltType.isInteger(32)) 169 name = "sparsePointers32"; 170 else if (eltType.isInteger(16)) 171 name = "sparsePointers16"; 172 else if (eltType.isInteger(8)) 173 name = "sparsePointers8"; 174 else 175 return failure(); 176 rewriter.replaceOpWithNewOp<CallOp>( 177 op, resType, getFunc(op, name, resType, operands), operands); 178 return success(); 179 } 180 }; 181 182 /// Sparse conversion rule for index accesses. 183 class SparseTensorToIndicesConverter : public OpConversionPattern<ToIndicesOp> { 184 public: 185 using OpConversionPattern::OpConversionPattern; 186 LogicalResult 187 matchAndRewrite(ToIndicesOp op, ArrayRef<Value> operands, 188 ConversionPatternRewriter &rewriter) const override { 189 Type resType = op.getType(); 190 Type eltType = resType.cast<ShapedType>().getElementType(); 191 StringRef name; 192 if (eltType.isIndex()) 193 name = "sparseIndices"; 194 else if (eltType.isInteger(64)) 195 name = "sparseIndices64"; 196 else if (eltType.isInteger(32)) 197 name = "sparseIndices32"; 198 else if (eltType.isInteger(16)) 199 name = "sparseIndices16"; 200 else if (eltType.isInteger(8)) 201 name = "sparseIndices8"; 202 else 203 return failure(); 204 rewriter.replaceOpWithNewOp<CallOp>( 205 op, resType, getFunc(op, name, resType, operands), operands); 206 return success(); 207 } 208 }; 209 210 /// Sparse conversion rule for value accesses. 211 class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> { 212 public: 213 using OpConversionPattern::OpConversionPattern; 214 LogicalResult 215 matchAndRewrite(ToValuesOp op, ArrayRef<Value> operands, 216 ConversionPatternRewriter &rewriter) const override { 217 Type resType = op.getType(); 218 Type eltType = resType.cast<ShapedType>().getElementType(); 219 StringRef name; 220 if (eltType.isF64()) 221 name = "sparseValuesF64"; 222 else if (eltType.isF32()) 223 name = "sparseValuesF32"; 224 else if (eltType.isInteger(32)) 225 name = "sparseValuesI32"; 226 else if (eltType.isInteger(16)) 227 name = "sparseValuesI16"; 228 else if (eltType.isInteger(8)) 229 name = "sparseValuesI8"; 230 else 231 return failure(); 232 rewriter.replaceOpWithNewOp<CallOp>( 233 op, resType, getFunc(op, name, resType, operands), operands); 234 return success(); 235 } 236 }; 237 238 } // namespace 239 240 /// Populates the given patterns list with conversion rules required for 241 /// the sparsification of linear algebra operations. 242 void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter, 243 RewritePatternSet &patterns) { 244 patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter, 245 SparseTensorNewConverter, SparseTensorToPointersConverter, 246 SparseTensorToIndicesConverter, SparseTensorToValuesConverter>( 247 typeConverter, patterns.getContext()); 248 } 249