1 //===- SparseTensorLowering.cpp - Sparse tensor primitives conversion -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Convert sparse tensor primitives to calls into a runtime support library. 10 // Note that this is a current implementation choice to keep the conversion 11 // simple. In principle, these primitives could also be converted to actual 12 // elaborate IR code that implements the primitives on the selected sparse 13 // tensor storage schemes. 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "mlir/Dialect/LLVMIR/LLVMTypes.h" 18 #include "mlir/Dialect/MemRef/IR/MemRef.h" 19 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 20 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" 21 #include "mlir/Dialect/StandardOps/IR/Ops.h" 22 #include "mlir/Dialect/Tensor/IR/Tensor.h" 23 #include "mlir/Transforms/DialectConversion.h" 24 25 using namespace mlir; 26 using namespace mlir::sparse_tensor; 27 28 namespace { 29 30 /// Internal encoding of primary storage. Keep this enum consistent 31 /// with the equivalent enum in the sparse runtime support library. 32 enum PrimaryTypeEnum : uint64_t { 33 kF64 = 1, 34 kF32 = 2, 35 kI64 = 3, 36 kI32 = 4, 37 kI16 = 5, 38 kI8 = 6 39 }; 40 41 /// Returns internal type encoding for overhead storage. Keep these 42 /// values consistent with the sparse runtime support library. 43 static unsigned getOverheadTypeEncoding(unsigned width) { 44 switch (width) { 45 default: 46 return 1; 47 case 32: 48 return 2; 49 case 16: 50 return 3; 51 case 8: 52 return 4; 53 } 54 } 55 56 /// Returns internal dimension level type encoding. Keep these 57 /// values consistent with the sparse runtime support library. 58 static unsigned 59 getDimLevelTypeEncoding(SparseTensorEncodingAttr::DimLevelType dlt) { 60 switch (dlt) { 61 case SparseTensorEncodingAttr::DimLevelType::Dense: 62 return 0; 63 case SparseTensorEncodingAttr::DimLevelType::Compressed: 64 return 1; 65 case SparseTensorEncodingAttr::DimLevelType::Singleton: 66 return 2; 67 } 68 llvm_unreachable("Unknown SparseTensorEncodingAttr::DimLevelType"); 69 } 70 71 /// Returns integers of given width and values as a constant tensor. 72 /// We cast the static shape into a dynamic shape to ensure that the 73 /// method signature remains uniform accross different tensor dimensions. 74 static Value getTensor(ConversionPatternRewriter &rewriter, unsigned width, 75 Location loc, ArrayRef<APInt> values) { 76 Type etp = rewriter.getIntegerType(width); 77 unsigned sz = values.size(); 78 RankedTensorType tt1 = RankedTensorType::get({sz}, etp); 79 RankedTensorType tt2 = RankedTensorType::get({ShapedType::kDynamicSize}, etp); 80 auto elts = 81 rewriter.create<ConstantOp>(loc, DenseElementsAttr::get(tt1, values)); 82 return rewriter.create<tensor::CastOp>(loc, tt2, elts); 83 } 84 85 /// Returns function reference (first hit also inserts into module). 86 static FlatSymbolRefAttr getFunc(Operation *op, StringRef name, Type result, 87 ValueRange operands) { 88 MLIRContext *context = op->getContext(); 89 auto module = op->getParentOfType<ModuleOp>(); 90 auto func = module.lookupSymbol<FuncOp>(name); 91 if (!func) { 92 OpBuilder moduleBuilder(module.getBodyRegion()); 93 moduleBuilder 94 .create<FuncOp>(op->getLoc(), name, 95 FunctionType::get(context, operands.getTypes(), result)) 96 .setPrivate(); 97 } 98 return SymbolRefAttr::get(context, name); 99 } 100 101 /// Sparse conversion rule for returns. 102 class SparseReturnConverter : public OpConversionPattern<ReturnOp> { 103 public: 104 using OpConversionPattern::OpConversionPattern; 105 LogicalResult 106 matchAndRewrite(ReturnOp op, ArrayRef<Value> operands, 107 ConversionPatternRewriter &rewriter) const override { 108 rewriter.replaceOpWithNewOp<ReturnOp>(op, operands); 109 return success(); 110 } 111 }; 112 113 /// Sparse conversion rule for dimension accesses. 114 class SparseTensorToDimSizeConverter 115 : public OpConversionPattern<tensor::DimOp> { 116 public: 117 using OpConversionPattern::OpConversionPattern; 118 LogicalResult 119 matchAndRewrite(tensor::DimOp op, ArrayRef<Value> operands, 120 ConversionPatternRewriter &rewriter) const override { 121 if (!operands[0].getType().isa<LLVM::LLVMPointerType>()) 122 return failure(); 123 Type resType = op.getType(); 124 StringRef name = "sparseDimSize"; 125 rewriter.replaceOpWithNewOp<CallOp>( 126 op, resType, getFunc(op, name, resType, operands), operands); 127 return success(); 128 } 129 }; 130 131 /// Sparse conversion rule for the new operator. 132 class SparseTensorNewConverter : public OpConversionPattern<NewOp> { 133 using OpConversionPattern::OpConversionPattern; 134 LogicalResult 135 matchAndRewrite(NewOp op, ArrayRef<Value> operands, 136 ConversionPatternRewriter &rewriter) const override { 137 Location loc = op.getLoc(); 138 Type resType = op.getType(); 139 Type eltType = resType.cast<ShapedType>().getElementType(); 140 MLIRContext *context = op->getContext(); 141 SmallVector<Value, 5> params; 142 // Sparse encoding. 143 auto enc = getSparseTensorEncoding(resType); 144 if (!enc) 145 return failure(); 146 // User pointer. 147 params.push_back(operands[0]); 148 // Sparsity annotations in tensor constant form. 149 SmallVector<APInt, 4> attrs; 150 unsigned sz = enc.getDimLevelType().size(); 151 for (unsigned i = 0; i < sz; i++) 152 attrs.push_back( 153 APInt(8, getDimLevelTypeEncoding(enc.getDimLevelType()[i]))); 154 params.push_back(getTensor(rewriter, 8, loc, attrs)); 155 // Dimension order permutation array. This is the "identity" 156 // permutation by default, or otherwise the "reverse" permutation 157 // of a given ordering, so that indices can be mapped quickly 158 // to the right position. 159 SmallVector<APInt, 4> perm(sz); 160 AffineMap p = enc.getDimOrdering(); 161 if (p) { 162 assert(p.isPermutation() && p.getNumResults() == sz); 163 for (unsigned i = 0; i < sz; i++) 164 perm[p.getDimPosition(i)] = APInt(64, i); 165 } else { 166 for (unsigned i = 0; i < sz; i++) 167 perm[i] = APInt(64, i); 168 } 169 params.push_back(getTensor(rewriter, 64, loc, perm)); 170 // Secondary and primary types encoding. 171 unsigned secPtr = getOverheadTypeEncoding(enc.getPointerBitWidth()); 172 unsigned secInd = getOverheadTypeEncoding(enc.getIndexBitWidth()); 173 unsigned primary; 174 if (eltType.isF64()) 175 primary = kF64; 176 else if (eltType.isF32()) 177 primary = kF32; 178 else if (eltType.isInteger(64)) 179 primary = kI64; 180 else if (eltType.isInteger(32)) 181 primary = kI32; 182 else if (eltType.isInteger(16)) 183 primary = kI16; 184 else if (eltType.isInteger(8)) 185 primary = kI8; 186 else 187 return failure(); 188 params.push_back( 189 rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(secPtr))); 190 params.push_back( 191 rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(secInd))); 192 params.push_back( 193 rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(primary))); 194 // Generate the call to create new tensor. 195 Type ptrType = LLVM::LLVMPointerType::get(IntegerType::get(context, 8)); 196 StringRef name = "newSparseTensor"; 197 rewriter.replaceOpWithNewOp<CallOp>( 198 op, ptrType, getFunc(op, name, ptrType, params), params); 199 return success(); 200 } 201 }; 202 203 /// Sparse conversion rule for pointer accesses. 204 class SparseTensorToPointersConverter 205 : public OpConversionPattern<ToPointersOp> { 206 public: 207 using OpConversionPattern::OpConversionPattern; 208 LogicalResult 209 matchAndRewrite(ToPointersOp op, ArrayRef<Value> operands, 210 ConversionPatternRewriter &rewriter) const override { 211 Type resType = op.getType(); 212 Type eltType = resType.cast<ShapedType>().getElementType(); 213 StringRef name; 214 if (eltType.isIndex()) 215 name = "sparsePointers"; 216 else if (eltType.isInteger(64)) 217 name = "sparsePointers64"; 218 else if (eltType.isInteger(32)) 219 name = "sparsePointers32"; 220 else if (eltType.isInteger(16)) 221 name = "sparsePointers16"; 222 else if (eltType.isInteger(8)) 223 name = "sparsePointers8"; 224 else 225 return failure(); 226 rewriter.replaceOpWithNewOp<CallOp>( 227 op, resType, getFunc(op, name, resType, operands), operands); 228 return success(); 229 } 230 }; 231 232 /// Sparse conversion rule for index accesses. 233 class SparseTensorToIndicesConverter : public OpConversionPattern<ToIndicesOp> { 234 public: 235 using OpConversionPattern::OpConversionPattern; 236 LogicalResult 237 matchAndRewrite(ToIndicesOp op, ArrayRef<Value> operands, 238 ConversionPatternRewriter &rewriter) const override { 239 Type resType = op.getType(); 240 Type eltType = resType.cast<ShapedType>().getElementType(); 241 StringRef name; 242 if (eltType.isIndex()) 243 name = "sparseIndices"; 244 else if (eltType.isInteger(64)) 245 name = "sparseIndices64"; 246 else if (eltType.isInteger(32)) 247 name = "sparseIndices32"; 248 else if (eltType.isInteger(16)) 249 name = "sparseIndices16"; 250 else if (eltType.isInteger(8)) 251 name = "sparseIndices8"; 252 else 253 return failure(); 254 rewriter.replaceOpWithNewOp<CallOp>( 255 op, resType, getFunc(op, name, resType, operands), operands); 256 return success(); 257 } 258 }; 259 260 /// Sparse conversion rule for value accesses. 261 class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> { 262 public: 263 using OpConversionPattern::OpConversionPattern; 264 LogicalResult 265 matchAndRewrite(ToValuesOp op, ArrayRef<Value> operands, 266 ConversionPatternRewriter &rewriter) const override { 267 Type resType = op.getType(); 268 Type eltType = resType.cast<ShapedType>().getElementType(); 269 StringRef name; 270 if (eltType.isF64()) 271 name = "sparseValuesF64"; 272 else if (eltType.isF32()) 273 name = "sparseValuesF32"; 274 else if (eltType.isInteger(64)) 275 name = "sparseValuesI64"; 276 else if (eltType.isInteger(32)) 277 name = "sparseValuesI32"; 278 else if (eltType.isInteger(16)) 279 name = "sparseValuesI16"; 280 else if (eltType.isInteger(8)) 281 name = "sparseValuesI8"; 282 else 283 return failure(); 284 rewriter.replaceOpWithNewOp<CallOp>( 285 op, resType, getFunc(op, name, resType, operands), operands); 286 return success(); 287 } 288 }; 289 290 /// Sparse conversion rule for tensor reconstruction. 291 class SparseTensorToTensorConverter : public OpConversionPattern<ToTensorOp> { 292 public: 293 using OpConversionPattern::OpConversionPattern; 294 LogicalResult 295 // Simply fold the operator into the pointer to the sparse storage scheme. 296 matchAndRewrite(ToTensorOp op, ArrayRef<Value> operands, 297 ConversionPatternRewriter &rewriter) const override { 298 // Check that all arguments of the tensor reconstruction operators are calls 299 // into the support library that query exactly the same opaque pointer. 300 Value ptr; 301 for (Value op : operands) { 302 if (auto call = op.getDefiningOp<CallOp>()) { 303 Value arg = call.getOperand(0); 304 if (!arg.getType().isa<LLVM::LLVMPointerType>()) 305 return failure(); 306 if (!ptr) 307 ptr = arg; 308 else if (arg != ptr) 309 return failure(); 310 } 311 } 312 // If a single opaque pointer is found, perform the folding. 313 if (!ptr) 314 return failure(); 315 rewriter.replaceOp(op, ptr); 316 return success(); 317 } 318 }; 319 320 } // namespace 321 322 /// Populates the given patterns list with conversion rules required for 323 /// the sparsification of linear algebra operations. 324 void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter, 325 RewritePatternSet &patterns) { 326 patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter, 327 SparseTensorNewConverter, SparseTensorToPointersConverter, 328 SparseTensorToIndicesConverter, SparseTensorToValuesConverter, 329 SparseTensorToTensorConverter>(typeConverter, 330 patterns.getContext()); 331 } 332