1 //===- SparseTensorDialect.cpp - Sparse tensor dialect implementation -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
10 #include "mlir/IR/Builders.h"
11 #include "mlir/IR/DialectImplementation.h"
12 #include "mlir/IR/Matchers.h"
13 #include "mlir/IR/OpImplementation.h"
14 #include "llvm/ADT/TypeSwitch.h"
15 
16 using namespace mlir;
17 using namespace mlir::sparse_tensor;
18 
19 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc"
20 
21 //===----------------------------------------------------------------------===//
// SparseTensorDialect Attribute Methods.
23 //===----------------------------------------------------------------------===//
24 
25 #define GET_ATTRDEF_CLASSES
26 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
27 
/// Returns true when `bitWidth` is a supported storage width for sparse
/// tensor pointers/indices: 0 (which isMatchingWidth maps to the `index`
/// type) or one of the fixed integer widths 8, 16, 32 and 64.
static bool acceptBitWidth(unsigned bitWidth) {
  const bool isDefaultIndex = bitWidth == 0;
  const bool isFixedWidth =
      bitWidth == 8 || bitWidth == 16 || bitWidth == 32 || bitWidth == 64;
  return isDefaultIndex || isFixedWidth;
}
40 
/// Parses a sparse tensor encoding written as `<{ key = value, ... }>`.
/// Recognized keys:
///   - "dimLevelType":    array of strings, each "dense", "compressed" or
///                        "singleton";
///   - "dimOrdering":     an affine map;
///   - "pointerBitWidth": an integer;
///   - "indexBitWidth":   an integer.
/// Any other key is rejected. Omitted keys keep their defaults: an empty
/// level-type list, a null ordering, and bitwidth 0. Returns a null attribute
/// on any syntax or value error reported here; structural validation is left
/// to getChecked(), which is expected to run verify() below. The `type`
/// parameter is not used.
Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
  if (failed(parser.parseLess()))
    return {};
  // Parse the data as a dictionary.
  DictionaryAttr dict;
  if (failed(parser.parseAttribute(dict)))
    return {};
  if (failed(parser.parseGreater()))
    return {};
  // Process the data from the parsed dictionary value into struct-like data.
  // Defaults used when the corresponding key is absent.
  SmallVector<SparseTensorEncodingAttr::DimLevelType, 4> dlt;
  AffineMap map = {};
  unsigned ptr = 0;
  unsigned ind = 0;
  // NOTE: every diagnostic below is anchored at parser.getNameLoc(), not at
  // the offending dictionary entry.
  for (const NamedAttribute &attr : dict) {
    if (attr.getName() == "dimLevelType") {
      auto arrayAttr = attr.getValue().dyn_cast<ArrayAttr>();
      if (!arrayAttr) {
        parser.emitError(parser.getNameLoc(),
                         "expected an array for dimension level types");
        return {};
      }
      // Map each string entry onto the DimLevelType enum, in order.
      for (auto i : arrayAttr) {
        auto strAttr = i.dyn_cast<StringAttr>();
        if (!strAttr) {
          parser.emitError(parser.getNameLoc(),
                           "expected a string value in dimension level types");
          return {};
        }
        auto strVal = strAttr.getValue();
        if (strVal == "dense") {
          dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Dense);
        } else if (strVal == "compressed") {
          dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Compressed);
        } else if (strVal == "singleton") {
          dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Singleton);
        } else {
          parser.emitError(parser.getNameLoc(),
                           "unexpected dimension level type: ")
              << strVal;
          return {};
        }
      }
    } else if (attr.getName() == "dimOrdering") {
      auto affineAttr = attr.getValue().dyn_cast<AffineMapAttr>();
      if (!affineAttr) {
        parser.emitError(parser.getNameLoc(),
                         "expected an affine map for dimension ordering");
        return {};
      }
      map = affineAttr.getValue();
    } else if (attr.getName() == "pointerBitWidth") {
      auto intAttr = attr.getValue().dyn_cast<IntegerAttr>();
      if (!intAttr) {
        parser.emitError(parser.getNameLoc(),
                         "expected an integral pointer bitwidth");
        return {};
      }
      // NOTE(review): getInt() yields a 64-bit signed value that is narrowed
      // into `unsigned` here; out-of-range inputs may wrap before verify()
      // checks them. Consider a range check.
      ptr = intAttr.getInt();
    } else if (attr.getName() == "indexBitWidth") {
      auto intAttr = attr.getValue().dyn_cast<IntegerAttr>();
      if (!intAttr) {
        parser.emitError(parser.getNameLoc(),
                         "expected an integral index bitwidth");
        return {};
      }
      // NOTE(review): same narrowing concern as pointerBitWidth above.
      ind = intAttr.getInt();
    } else {
      parser.emitError(parser.getNameLoc(), "unexpected key: ")
          << attr.getName().strref();
      return {};
    }
  }
  // Construct struct-like storage for attribute.
  return parser.getChecked<SparseTensorEncodingAttr>(parser.getContext(), dlt,
                                                     map, ptr, ind);
}
118 
119 void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
120   // Print the struct-like storage in dictionary fashion.
121   printer << "<{ dimLevelType = [ ";
122   for (unsigned i = 0, e = getDimLevelType().size(); i < e; i++) {
123     switch (getDimLevelType()[i]) {
124     case DimLevelType::Dense:
125       printer << "\"dense\"";
126       break;
127     case DimLevelType::Compressed:
128       printer << "\"compressed\"";
129       break;
130     case DimLevelType::Singleton:
131       printer << "\"singleton\"";
132       break;
133     }
134     if (i != e - 1)
135       printer << ", ";
136   }
137   printer << " ]";
138   if (getDimOrdering())
139     printer << ", dimOrdering = affine_map<" << getDimOrdering() << ">";
140   printer << ", pointerBitWidth = " << getPointerBitWidth()
141           << ", indexBitWidth = " << getIndexBitWidth() << " }>";
142 }
143 
144 LogicalResult SparseTensorEncodingAttr::verify(
145     function_ref<InFlightDiagnostic()> emitError,
146     ArrayRef<DimLevelType> dimLevelType, AffineMap dimOrdering,
147     unsigned pointerBitWidth, unsigned indexBitWidth) {
148   if (!acceptBitWidth(pointerBitWidth))
149     return emitError() << "unexpected pointer bitwidth: " << pointerBitWidth;
150   if (!acceptBitWidth(indexBitWidth))
151     return emitError() << "unexpected index bitwidth: " << indexBitWidth;
152   if (dimOrdering) {
153     if (!dimOrdering.isPermutation())
154       return emitError()
155              << "expected a permutation affine map for dimension ordering";
156     if (dimOrdering.getNumResults() != dimLevelType.size())
157       return emitError() << "unexpected mismatch in ordering and dimension "
158                             "level types size";
159   }
160   return success();
161 }
162 
163 LogicalResult SparseTensorEncodingAttr::verifyEncoding(
164     ArrayRef<int64_t> shape, Type elementType,
165     function_ref<InFlightDiagnostic()> emitError) const {
166   // Check structural integrity.
167   if (failed(verify(emitError, getDimLevelType(), getDimOrdering(),
168                     getPointerBitWidth(), getIndexBitWidth())))
169     return failure();
170   // Check integrity with tensor type specifics. Dimension ordering is optional,
171   // but we always should have dimension level types for the full rank.
172   unsigned size = shape.size();
173   if (size == 0)
174     return emitError() << "expected non-scalar sparse tensor";
175   if (getDimOrdering() && getDimOrdering().getNumResults() != size)
176     return emitError() << "expected an affine map of size " << size
177                        << " for dimension ordering";
178   if (getDimLevelType().size() != size)
179     return emitError() << "expected an array of size " << size
180                        << " for dimension level types";
181   return success();
182 }
183 
184 SparseTensorEncodingAttr
185 mlir::sparse_tensor::getSparseTensorEncoding(Type type) {
186   if (auto ttp = type.dyn_cast<RankedTensorType>())
187     return ttp.getEncoding().dyn_cast_or_null<SparseTensorEncodingAttr>();
188   return nullptr;
189 }
190 
191 //===----------------------------------------------------------------------===//
// SparseTensorDialect Operations.
193 //===----------------------------------------------------------------------===//
194 
195 static LogicalResult isInBounds(Value dim, Value tensor) {
196   IntegerAttr constantAttr;
197   if (matchPattern(dim, m_Constant(&constantAttr))) {
198     unsigned d = constantAttr.getInt();
199     if (d >= tensor.getType().cast<RankedTensorType>().getRank())
200       return failure();
201   }
202   return success(); // in bounds, or symbolic
203 }
204 
205 static LogicalResult isMatchingWidth(Value result, unsigned width) {
206   Type etp = result.getType().cast<MemRefType>().getElementType();
207   if ((width == 0 && etp.isIndex()) || (width > 0 && etp.isInteger(width)))
208     return success();
209   return failure();
210 }
211 
212 static LogicalResult verify(NewOp op) {
213   if (!getSparseTensorEncoding(op.result().getType()))
214     return op.emitError("expected a sparse tensor result");
215   return success();
216 }
217 
218 static LogicalResult verify(InitOp op) {
219   if (!getSparseTensorEncoding(op.result().getType()))
220     return op.emitError("expected a sparse tensor result");
221   RankedTensorType ttp = op.getType().cast<RankedTensorType>();
222   unsigned rank = ttp.getRank();
223   if (rank != op.sizes().size())
224     return op.emitError("unexpected mismatch between tensor rank and sizes: ")
225            << rank << " vs. " << op.sizes().size();
226   auto shape = ttp.getShape();
227   for (unsigned i = 0; i < rank; i++) {
228     if (shape[i] == ShapedType::kDynamicSize)
229       continue;
230     IntegerAttr constantAttr;
231     if (!matchPattern(op.sizes()[i], m_Constant(&constantAttr)) ||
232         constantAttr.getInt() != shape[i]) {
233       return op.emitError("unexpected mismatch with static dimension size ")
234              << shape[i];
235     }
236   }
237   return success();
238 }
239 
240 static LogicalResult verify(ConvertOp op) {
241   if (auto tp1 = op.source().getType().dyn_cast<RankedTensorType>()) {
242     if (auto tp2 = op.dest().getType().dyn_cast<RankedTensorType>()) {
243       if (tp1.getRank() != tp2.getRank())
244         return op.emitError("unexpected conversion mismatch in rank");
245       auto shape1 = tp1.getShape();
246       auto shape2 = tp2.getShape();
247       // Accept size matches between the source and the destination type
248       // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or
249       // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10).
250       for (unsigned d = 0, rank = tp1.getRank(); d < rank; d++) {
251         if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamicSize)
252           return op.emitError("unexpected conversion mismatch in dimension ")
253                  << d;
254       }
255       return success();
256     }
257   }
258   return op.emitError("unexpected type in convert");
259 }
260 
261 OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) {
262   if (getType() == source().getType())
263     return source();
264   return {};
265 }
266 
267 static LogicalResult verify(ToPointersOp op) {
268   if (auto e = getSparseTensorEncoding(op.tensor().getType())) {
269     if (failed(isInBounds(op.dim(), op.tensor())))
270       return op.emitError("requested pointers dimension out of bounds");
271     if (failed(isMatchingWidth(op.result(), e.getPointerBitWidth())))
272       return op.emitError("unexpected type for pointers");
273     return success();
274   }
275   return op.emitError("expected a sparse tensor to get pointers");
276 }
277 
278 static LogicalResult verify(ToIndicesOp op) {
279   if (auto e = getSparseTensorEncoding(op.tensor().getType())) {
280     if (failed(isInBounds(op.dim(), op.tensor())))
281       return op.emitError("requested indices dimension out of bounds");
282     if (failed(isMatchingWidth(op.result(), e.getIndexBitWidth())))
283       return op.emitError("unexpected type for indices");
284     return success();
285   }
286   return op.emitError("expected a sparse tensor to get indices");
287 }
288 
289 static LogicalResult verify(ToValuesOp op) {
290   if (!getSparseTensorEncoding(op.tensor().getType()))
291     return op.emitError("expected a sparse tensor to get values");
292   RankedTensorType ttp = op.tensor().getType().cast<RankedTensorType>();
293   MemRefType mtp = op.result().getType().cast<MemRefType>();
294   if (ttp.getElementType() != mtp.getElementType())
295     return op.emitError("unexpected mismatch in element types");
296   return success();
297 }
298 
299 //===----------------------------------------------------------------------===//
// SparseTensorDialect Management Operations.
301 //===----------------------------------------------------------------------===//
302 
303 static LogicalResult verify(LexInsertOp op) {
304   if (!getSparseTensorEncoding(op.tensor().getType()))
305     return op.emitError("expected a sparse tensor for insertion");
306   return success();
307 }
308 
309 static LogicalResult verify(ExpandOp op) {
310   if (!getSparseTensorEncoding(op.tensor().getType()))
311     return op.emitError("expected a sparse tensor for expansion");
312   return success();
313 }
314 
315 static LogicalResult verify(CompressOp op) {
316   if (!getSparseTensorEncoding(op.tensor().getType()))
317     return op.emitError("expected a sparse tensor for compression");
318   return success();
319 }
320 
321 static LogicalResult verify(LoadOp op) {
322   if (!getSparseTensorEncoding(op.tensor().getType()))
323     return op.emitError("expected a sparse tensor to materialize");
324   return success();
325 }
326 
327 static LogicalResult verify(ReleaseOp op) {
328   if (!getSparseTensorEncoding(op.tensor().getType()))
329     return op.emitError("expected a sparse tensor to release");
330   return success();
331 }
332 
333 static LogicalResult verify(OutOp op) {
334   if (!getSparseTensorEncoding(op.tensor().getType()))
335     return op.emitError("expected a sparse tensor for output");
336   return success();
337 }
338 
339 //===----------------------------------------------------------------------===//
// SparseTensorDialect Methods.
341 //===----------------------------------------------------------------------===//
342 
/// Registers the dialect's attributes and operations with the dialect.
/// The class lists are not written by hand. They are expanded from the
/// generated .cpp.inc files by defining GET_ATTRDEF_LIST and GET_OP_LIST
/// before each include.
void SparseTensorDialect::initialize() {
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
      >();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
      >();
}
353 
354 #define GET_OP_CLASSES
355 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
356