1 //===- SparseTensorLowering.cpp - Sparse tensor primitives conversion -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Convert sparse tensor primitives to calls into a runtime support library. 10 // Note that this is a current implementation choice to keep the conversion 11 // simple. In principle, these primitives could also be converted to actual 12 // elaborate IR code that implements the primitives on the selected sparse 13 // tensor storage schemes. 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "mlir/Dialect/LLVMIR/LLVMTypes.h" 18 #include "mlir/Dialect/MemRef/IR/MemRef.h" 19 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 20 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" 21 #include "mlir/Dialect/StandardOps/IR/Ops.h" 22 #include "mlir/Transforms/DialectConversion.h" 23 24 using namespace mlir; 25 using namespace mlir::sparse_tensor; 26 27 namespace { 28 29 /// Returns internal type encoding for overhead storage. 30 static unsigned getOverheadTypeEncoding(unsigned width) { 31 switch (width) { 32 default: 33 return 1; 34 case 32: 35 return 2; 36 case 16: 37 return 3; 38 case 8: 39 return 4; 40 } 41 } 42 43 /// Returns function reference (first hit also inserts into module). 44 static FlatSymbolRefAttr getFunc(Operation *op, StringRef name, Type result, 45 ValueRange operands) { 46 MLIRContext *context = op->getContext(); 47 auto module = op->getParentOfType<ModuleOp>(); 48 auto func = module.lookupSymbol<FuncOp>(name); 49 if (!func) { 50 OpBuilder moduleBuilder(module.getBodyRegion()); 51 moduleBuilder 52 .create<FuncOp>(op->getLoc(), name, 53 FunctionType::get(context, operands.getTypes(), result)) 54 .setPrivate(); 55 } 56 return SymbolRefAttr::get(context, name); 57 } 58 59 /// Sparse conversion rule for returns. 60 class SparseReturnConverter : public OpConversionPattern<ReturnOp> { 61 public: 62 using OpConversionPattern::OpConversionPattern; 63 LogicalResult 64 matchAndRewrite(ReturnOp op, ArrayRef<Value> operands, 65 ConversionPatternRewriter &rewriter) const override { 66 rewriter.replaceOpWithNewOp<ReturnOp>(op, operands); 67 return success(); 68 } 69 }; 70 71 /// Sparse conversion rule for dimension accesses. 72 class SparseTensorToDimSizeConverter 73 : public OpConversionPattern<memref::DimOp> { 74 public: 75 using OpConversionPattern::OpConversionPattern; 76 LogicalResult 77 matchAndRewrite(memref::DimOp op, ArrayRef<Value> operands, 78 ConversionPatternRewriter &rewriter) const override { 79 if (!operands[0].getType().isa<LLVM::LLVMPointerType>()) 80 return failure(); 81 Type resType = op.getType(); 82 StringRef name = "sparseDimSize"; 83 rewriter.replaceOpWithNewOp<CallOp>( 84 op, resType, getFunc(op, name, resType, operands), operands); 85 return success(); 86 } 87 }; 88 89 /// Sparse conversion rule for the new operator. 90 class SparseTensorNewConverter : public OpConversionPattern<NewOp> { 91 using OpConversionPattern::OpConversionPattern; 92 LogicalResult 93 matchAndRewrite(NewOp op, ArrayRef<Value> operands, 94 ConversionPatternRewriter &rewriter) const override { 95 Location loc = op.getLoc(); 96 Type resType = op.getType(); 97 Type eltType = resType.cast<ShapedType>().getElementType(); 98 MLIRContext *context = op->getContext(); 99 SmallVector<Value, 5> params; 100 // Sparse encoding. 101 auto enc = getSparseTensorEncoding(resType); 102 if (!enc) 103 return failure(); 104 // User pointer. 105 params.push_back(operands[0]); 106 // Sparsity annotations. 107 SmallVector<bool, 4> attrs; 108 unsigned sz = enc.getDimLevelType().size(); 109 for (unsigned i = 0; i < sz; i++) 110 attrs.push_back(enc.getDimLevelType()[i] == 111 SparseTensorEncodingAttr::DimLevelType::Compressed); 112 auto elts = DenseElementsAttr::get( 113 RankedTensorType::get({sz}, rewriter.getIntegerType(1)), attrs); 114 params.push_back(rewriter.create<ConstantOp>(loc, elts)); 115 // Seconary and primary types encoding. 116 unsigned secPtr = getOverheadTypeEncoding(enc.getPointerBitWidth()); 117 unsigned secInd = getOverheadTypeEncoding(enc.getIndexBitWidth()); 118 unsigned primary; 119 if (eltType.isF64()) 120 primary = 1; 121 else if (eltType.isF32()) 122 primary = 2; 123 else if (eltType.isInteger(32)) 124 primary = 3; 125 else if (eltType.isInteger(16)) 126 primary = 4; 127 else if (eltType.isInteger(8)) 128 primary = 5; 129 else 130 return failure(); 131 params.push_back( 132 rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(secPtr))); 133 params.push_back( 134 rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(secInd))); 135 params.push_back( 136 rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(primary))); 137 // Generate the call to create new tensor. 138 Type ptrType = LLVM::LLVMPointerType::get(IntegerType::get(context, 8)); 139 StringRef name = "newSparseTensor"; 140 rewriter.replaceOpWithNewOp<CallOp>( 141 op, ptrType, getFunc(op, name, ptrType, params), params); 142 return success(); 143 } 144 }; 145 146 /// Sparse conversion rule for pointer accesses. 147 class SparseTensorToPointersConverter 148 : public OpConversionPattern<ToPointersOp> { 149 public: 150 using OpConversionPattern::OpConversionPattern; 151 LogicalResult 152 matchAndRewrite(ToPointersOp op, ArrayRef<Value> operands, 153 ConversionPatternRewriter &rewriter) const override { 154 Type resType = op.getType(); 155 Type eltType = resType.cast<ShapedType>().getElementType(); 156 StringRef name; 157 if (eltType.isIndex()) 158 name = "sparsePointers"; 159 else if (eltType.isInteger(64)) 160 name = "sparsePointers64"; 161 else if (eltType.isInteger(32)) 162 name = "sparsePointers32"; 163 else if (eltType.isInteger(16)) 164 name = "sparsePointers16"; 165 else if (eltType.isInteger(8)) 166 name = "sparsePointers8"; 167 else 168 return failure(); 169 rewriter.replaceOpWithNewOp<CallOp>( 170 op, resType, getFunc(op, name, resType, operands), operands); 171 return success(); 172 } 173 }; 174 175 /// Sparse conversion rule for index accesses. 176 class SparseTensorToIndicesConverter : public OpConversionPattern<ToIndicesOp> { 177 public: 178 using OpConversionPattern::OpConversionPattern; 179 LogicalResult 180 matchAndRewrite(ToIndicesOp op, ArrayRef<Value> operands, 181 ConversionPatternRewriter &rewriter) const override { 182 Type resType = op.getType(); 183 Type eltType = resType.cast<ShapedType>().getElementType(); 184 StringRef name; 185 if (eltType.isIndex()) 186 name = "sparseIndices"; 187 else if (eltType.isInteger(64)) 188 name = "sparseIndices64"; 189 else if (eltType.isInteger(32)) 190 name = "sparseIndices32"; 191 else if (eltType.isInteger(16)) 192 name = "sparseIndices16"; 193 else if (eltType.isInteger(8)) 194 name = "sparseIndices8"; 195 else 196 return failure(); 197 rewriter.replaceOpWithNewOp<CallOp>( 198 op, resType, getFunc(op, name, resType, operands), operands); 199 return success(); 200 } 201 }; 202 203 /// Sparse conversion rule for value accesses. 204 class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> { 205 public: 206 using OpConversionPattern::OpConversionPattern; 207 LogicalResult 208 matchAndRewrite(ToValuesOp op, ArrayRef<Value> operands, 209 ConversionPatternRewriter &rewriter) const override { 210 Type resType = op.getType(); 211 Type eltType = resType.cast<ShapedType>().getElementType(); 212 StringRef name; 213 if (eltType.isF64()) 214 name = "sparseValuesF64"; 215 else if (eltType.isF32()) 216 name = "sparseValuesF32"; 217 else if (eltType.isInteger(32)) 218 name = "sparseValuesI32"; 219 else if (eltType.isInteger(16)) 220 name = "sparseValuesI16"; 221 else if (eltType.isInteger(8)) 222 name = "sparseValuesI8"; 223 else 224 return failure(); 225 rewriter.replaceOpWithNewOp<CallOp>( 226 op, resType, getFunc(op, name, resType, operands), operands); 227 return success(); 228 } 229 }; 230 231 } // namespace 232 233 /// Populates the given patterns list with conversion rules required for 234 /// the sparsification of linear algebra operations. 235 void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter, 236 RewritePatternSet &patterns) { 237 patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter, 238 SparseTensorNewConverter, SparseTensorToPointersConverter, 239 SparseTensorToIndicesConverter, SparseTensorToValuesConverter>( 240 typeConverter, patterns.getContext()); 241 } 242