1 //===- SparseTensorLowering.cpp - Sparse tensor primitives conversion -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Convert sparse tensor primitives to calls into a runtime support library.
10 // Note that this is a current implementation choice to keep the conversion
11 // simple. In principle, these primitives could also be converted to actual
12 // elaborate IR code that implements the primitives on the selected sparse
13 // tensor storage schemes.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
18 #include "mlir/Dialect/MemRef/IR/MemRef.h"
19 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
20 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
21 #include "mlir/Dialect/StandardOps/IR/Ops.h"
22 #include "mlir/Dialect/Tensor/IR/Tensor.h"
23 #include "mlir/Transforms/DialectConversion.h"
24 
25 using namespace mlir;
26 using namespace mlir::sparse_tensor;
27 
28 namespace {
29 
/// Maps an overhead-storage bit width (the pointer or index bit width of a
/// sparse tensor encoding) to the integer code that is passed to the runtime
/// support library: 32 -> 2, 16 -> 3, 8 -> 4, and any other width -> 1
/// (presumably 64-bit / native index width; confirm against the runtime).
static unsigned getOverheadTypeEncoding(unsigned width) {
  if (width == 8)
    return 4;
  if (width == 16)
    return 3;
  if (width == 32)
    return 2;
  // Everything else shares the default code.
  return 1;
}
43 
44 /// Returns function reference (first hit also inserts into module).
45 static FlatSymbolRefAttr getFunc(Operation *op, StringRef name, Type result,
46                                  ValueRange operands) {
47   MLIRContext *context = op->getContext();
48   auto module = op->getParentOfType<ModuleOp>();
49   auto func = module.lookupSymbol<FuncOp>(name);
50   if (!func) {
51     OpBuilder moduleBuilder(module.getBodyRegion());
52     moduleBuilder
53         .create<FuncOp>(op->getLoc(), name,
54                         FunctionType::get(context, operands.getTypes(), result))
55         .setPrivate();
56   }
57   return SymbolRefAttr::get(context, name);
58 }
59 
60 /// Sparse conversion rule for returns.
61 class SparseReturnConverter : public OpConversionPattern<ReturnOp> {
62 public:
63   using OpConversionPattern::OpConversionPattern;
64   LogicalResult
65   matchAndRewrite(ReturnOp op, ArrayRef<Value> operands,
66                   ConversionPatternRewriter &rewriter) const override {
67     rewriter.replaceOpWithNewOp<ReturnOp>(op, operands);
68     return success();
69   }
70 };
71 
72 /// Sparse conversion rule for dimension accesses.
73 class SparseTensorToDimSizeConverter
74     : public OpConversionPattern<memref::DimOp> {
75 public:
76   using OpConversionPattern::OpConversionPattern;
77   LogicalResult
78   matchAndRewrite(memref::DimOp op, ArrayRef<Value> operands,
79                   ConversionPatternRewriter &rewriter) const override {
80     if (!operands[0].getType().isa<LLVM::LLVMPointerType>())
81       return failure();
82     Type resType = op.getType();
83     StringRef name = "sparseDimSize";
84     rewriter.replaceOpWithNewOp<CallOp>(
85         op, resType, getFunc(op, name, resType, operands), operands);
86     return success();
87   }
88 };
89 
90 /// Sparse conversion rule for the new operator.
91 class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
92   using OpConversionPattern::OpConversionPattern;
93   LogicalResult
94   matchAndRewrite(NewOp op, ArrayRef<Value> operands,
95                   ConversionPatternRewriter &rewriter) const override {
96     Location loc = op.getLoc();
97     Type resType = op.getType();
98     Type eltType = resType.cast<ShapedType>().getElementType();
99     MLIRContext *context = op->getContext();
100     SmallVector<Value, 5> params;
101     // Sparse encoding.
102     auto enc = getSparseTensorEncoding(resType);
103     if (!enc)
104       return failure();
105     // User pointer.
106     params.push_back(operands[0]);
107     // Sparsity annotations in tensor constant form. Note that we cast
108     // the static shape into a dynamic shape to ensure that the method
109     // signature remains uniform accross different tensor dimensions.
110     SmallVector<bool, 4> attrs;
111     unsigned sz = enc.getDimLevelType().size();
112     for (unsigned i = 0; i < sz; i++)
113       attrs.push_back(enc.getDimLevelType()[i] ==
114                       SparseTensorEncodingAttr::DimLevelType::Compressed);
115     Type etp = rewriter.getIntegerType(1);
116     RankedTensorType tt1 = RankedTensorType::get({sz}, etp);
117     RankedTensorType tt2 =
118         RankedTensorType::get({ShapedType::kDynamicSize}, etp);
119     auto elts =
120         rewriter.create<ConstantOp>(loc, DenseElementsAttr::get(tt1, attrs));
121     params.push_back(rewriter.create<tensor::CastOp>(loc, tt2, elts));
122     // Seconary and primary types encoding.
123     unsigned secPtr = getOverheadTypeEncoding(enc.getPointerBitWidth());
124     unsigned secInd = getOverheadTypeEncoding(enc.getIndexBitWidth());
125     unsigned primary;
126     if (eltType.isF64())
127       primary = 1;
128     else if (eltType.isF32())
129       primary = 2;
130     else if (eltType.isInteger(32))
131       primary = 3;
132     else if (eltType.isInteger(16))
133       primary = 4;
134     else if (eltType.isInteger(8))
135       primary = 5;
136     else
137       return failure();
138     params.push_back(
139         rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(secPtr)));
140     params.push_back(
141         rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(secInd)));
142     params.push_back(
143         rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(primary)));
144     // Generate the call to create new tensor.
145     Type ptrType = LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
146     StringRef name = "newSparseTensor";
147     rewriter.replaceOpWithNewOp<CallOp>(
148         op, ptrType, getFunc(op, name, ptrType, params), params);
149     return success();
150   }
151 };
152 
153 /// Sparse conversion rule for pointer accesses.
154 class SparseTensorToPointersConverter
155     : public OpConversionPattern<ToPointersOp> {
156 public:
157   using OpConversionPattern::OpConversionPattern;
158   LogicalResult
159   matchAndRewrite(ToPointersOp op, ArrayRef<Value> operands,
160                   ConversionPatternRewriter &rewriter) const override {
161     Type resType = op.getType();
162     Type eltType = resType.cast<ShapedType>().getElementType();
163     StringRef name;
164     if (eltType.isIndex())
165       name = "sparsePointers";
166     else if (eltType.isInteger(64))
167       name = "sparsePointers64";
168     else if (eltType.isInteger(32))
169       name = "sparsePointers32";
170     else if (eltType.isInteger(16))
171       name = "sparsePointers16";
172     else if (eltType.isInteger(8))
173       name = "sparsePointers8";
174     else
175       return failure();
176     rewriter.replaceOpWithNewOp<CallOp>(
177         op, resType, getFunc(op, name, resType, operands), operands);
178     return success();
179   }
180 };
181 
182 /// Sparse conversion rule for index accesses.
183 class SparseTensorToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
184 public:
185   using OpConversionPattern::OpConversionPattern;
186   LogicalResult
187   matchAndRewrite(ToIndicesOp op, ArrayRef<Value> operands,
188                   ConversionPatternRewriter &rewriter) const override {
189     Type resType = op.getType();
190     Type eltType = resType.cast<ShapedType>().getElementType();
191     StringRef name;
192     if (eltType.isIndex())
193       name = "sparseIndices";
194     else if (eltType.isInteger(64))
195       name = "sparseIndices64";
196     else if (eltType.isInteger(32))
197       name = "sparseIndices32";
198     else if (eltType.isInteger(16))
199       name = "sparseIndices16";
200     else if (eltType.isInteger(8))
201       name = "sparseIndices8";
202     else
203       return failure();
204     rewriter.replaceOpWithNewOp<CallOp>(
205         op, resType, getFunc(op, name, resType, operands), operands);
206     return success();
207   }
208 };
209 
210 /// Sparse conversion rule for value accesses.
211 class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
212 public:
213   using OpConversionPattern::OpConversionPattern;
214   LogicalResult
215   matchAndRewrite(ToValuesOp op, ArrayRef<Value> operands,
216                   ConversionPatternRewriter &rewriter) const override {
217     Type resType = op.getType();
218     Type eltType = resType.cast<ShapedType>().getElementType();
219     StringRef name;
220     if (eltType.isF64())
221       name = "sparseValuesF64";
222     else if (eltType.isF32())
223       name = "sparseValuesF32";
224     else if (eltType.isInteger(32))
225       name = "sparseValuesI32";
226     else if (eltType.isInteger(16))
227       name = "sparseValuesI16";
228     else if (eltType.isInteger(8))
229       name = "sparseValuesI8";
230     else
231       return failure();
232     rewriter.replaceOpWithNewOp<CallOp>(
233         op, resType, getFunc(op, name, resType, operands), operands);
234     return success();
235   }
236 };
237 
238 } // namespace
239 
240 /// Populates the given patterns list with conversion rules required for
241 /// the sparsification of linear algebra operations.
242 void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
243                                                   RewritePatternSet &patterns) {
244   patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter,
245                SparseTensorNewConverter, SparseTensorToPointersConverter,
246                SparseTensorToIndicesConverter, SparseTensorToValuesConverter>(
247       typeConverter, patterns.getContext());
248 }
249