1 //===- SparseTensorDialect.cpp - Sparse tensor dialect implementation -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
10 #include "mlir/IR/Builders.h"
11 #include "mlir/IR/DialectImplementation.h"
12 #include "mlir/IR/Matchers.h"
13 #include "mlir/IR/OpImplementation.h"
14 #include "llvm/ADT/TypeSwitch.h"
15 
16 using namespace mlir;
17 using namespace mlir::sparse_tensor;
18 
19 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc"
20 
21 //===----------------------------------------------------------------------===//
22 // TensorDialect Attribute Methods.
23 //===----------------------------------------------------------------------===//
24 
25 #define GET_ATTRDEF_CLASSES
26 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
27 
/// Returns true iff `bitWidth` is one of the bitwidths the sparse tensor
/// storage scheme supports (0 denotes the native `index` width).
static bool acceptBitWidth(unsigned bitWidth) {
  return bitWidth == 0 || bitWidth == 8 || bitWidth == 16 ||
         bitWidth == 32 || bitWidth == 64;
}
40 
Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
  // Parses the custom syntax `<{ ... }>` of the sparse tensor encoding,
  // where the inner dictionary may contain the keys "dimLevelType",
  // "dimOrdering", "pointerBitWidth", and "indexBitWidth". Any syntax or
  // semantic error emits a diagnostic and yields a null attribute.
  if (failed(parser.parseLess()))
    return {};
  // Parse the data as a dictionary.
  DictionaryAttr dict;
  if (failed(parser.parseAttribute(dict)))
    return {};
  if (failed(parser.parseGreater()))
    return {};
  // Process the data from the parsed dictionary value into struct-like data.
  // All fields are optional; absent entries keep these defaults (empty level
  // types, no ordering, bitwidth 0 = native index width).
  SmallVector<SparseTensorEncodingAttr::DimLevelType, 4> dlt;
  AffineMap map = {};
  unsigned ptr = 0;
  unsigned ind = 0;
  for (const NamedAttribute &attr : dict) {
    if (attr.getName() == "dimLevelType") {
      // Expect an array of string literals, one per dimension.
      auto arrayAttr = attr.getValue().dyn_cast<ArrayAttr>();
      if (!arrayAttr) {
        parser.emitError(parser.getNameLoc(),
                         "expected an array for dimension level types");
        return {};
      }
      for (auto i : arrayAttr) {
        auto strAttr = i.dyn_cast<StringAttr>();
        if (!strAttr) {
          parser.emitError(parser.getNameLoc(),
                           "expected a string value in dimension level types");
          return {};
        }
        // Map each recognized keyword onto its enum value.
        auto strVal = strAttr.getValue();
        if (strVal == "dense") {
          dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Dense);
        } else if (strVal == "compressed") {
          dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Compressed);
        } else if (strVal == "singleton") {
          dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Singleton);
        } else {
          parser.emitError(parser.getNameLoc(),
                           "unexpected dimension level type: ")
              << strVal;
          return {};
        }
      }
    } else if (attr.getName() == "dimOrdering") {
      // Expect an affine map describing the dimension permutation.
      auto affineAttr = attr.getValue().dyn_cast<AffineMapAttr>();
      if (!affineAttr) {
        parser.emitError(parser.getNameLoc(),
                         "expected an affine map for dimension ordering");
        return {};
      }
      map = affineAttr.getValue();
    } else if (attr.getName() == "pointerBitWidth") {
      auto intAttr = attr.getValue().dyn_cast<IntegerAttr>();
      if (!intAttr) {
        parser.emitError(parser.getNameLoc(),
                         "expected an integral pointer bitwidth");
        return {};
      }
      ptr = intAttr.getInt();
    } else if (attr.getName() == "indexBitWidth") {
      auto intAttr = attr.getValue().dyn_cast<IntegerAttr>();
      if (!intAttr) {
        parser.emitError(parser.getNameLoc(),
                         "expected an integral index bitwidth");
        return {};
      }
      ind = intAttr.getInt();
    } else {
      // Reject unknown dictionary keys outright.
      parser.emitError(parser.getNameLoc(), "unexpected key: ")
          << attr.getName().strref();
      return {};
    }
  }
  // Construct struct-like storage for attribute. getChecked runs the
  // attribute's verifier, so invalid combinations still yield null here.
  return parser.getChecked<SparseTensorEncodingAttr>(parser.getContext(), dlt,
                                                     map, ptr, ind);
}
118 
119 void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
120   // Print the struct-like storage in dictionary fashion.
121   printer << "<{ dimLevelType = [ ";
122   for (unsigned i = 0, e = getDimLevelType().size(); i < e; i++) {
123     switch (getDimLevelType()[i]) {
124     case DimLevelType::Dense:
125       printer << "\"dense\"";
126       break;
127     case DimLevelType::Compressed:
128       printer << "\"compressed\"";
129       break;
130     case DimLevelType::Singleton:
131       printer << "\"singleton\"";
132       break;
133     }
134     if (i != e - 1)
135       printer << ", ";
136   }
137   printer << " ]";
138   if (getDimOrdering())
139     printer << ", dimOrdering = affine_map<" << getDimOrdering() << ">";
140   printer << ", pointerBitWidth = " << getPointerBitWidth()
141           << ", indexBitWidth = " << getIndexBitWidth() << " }>";
142 }
143 
144 LogicalResult SparseTensorEncodingAttr::verify(
145     function_ref<InFlightDiagnostic()> emitError,
146     ArrayRef<DimLevelType> dimLevelType, AffineMap dimOrdering,
147     unsigned pointerBitWidth, unsigned indexBitWidth) {
148   if (!acceptBitWidth(pointerBitWidth))
149     return emitError() << "unexpected pointer bitwidth: " << pointerBitWidth;
150   if (!acceptBitWidth(indexBitWidth))
151     return emitError() << "unexpected index bitwidth: " << indexBitWidth;
152   if (dimOrdering) {
153     if (!dimOrdering.isPermutation())
154       return emitError()
155              << "expected a permutation affine map for dimension ordering";
156     if (dimOrdering.getNumResults() != dimLevelType.size())
157       return emitError() << "unexpected mismatch in ordering and dimension "
158                             "level types size";
159   }
160   return success();
161 }
162 
163 LogicalResult SparseTensorEncodingAttr::verifyEncoding(
164     ArrayRef<int64_t> shape, Type elementType,
165     function_ref<InFlightDiagnostic()> emitError) const {
166   // Check structural integrity.
167   if (failed(verify(emitError, getDimLevelType(), getDimOrdering(),
168                     getPointerBitWidth(), getIndexBitWidth())))
169     return failure();
170   // Check integrity with tensor type specifics. Dimension ordering is optional,
171   // but we always should have dimension level types for the full rank.
172   unsigned size = shape.size();
173   if (size == 0)
174     return emitError() << "expected non-scalar sparse tensor";
175   if (getDimOrdering() && getDimOrdering().getNumResults() != size)
176     return emitError() << "expected an affine map of size " << size
177                        << " for dimension ordering";
178   if (getDimLevelType().size() != size)
179     return emitError() << "expected an array of size " << size
180                        << " for dimension level types";
181   return success();
182 }
183 
184 SparseTensorEncodingAttr
185 mlir::sparse_tensor::getSparseTensorEncoding(Type type) {
186   if (auto ttp = type.dyn_cast<RankedTensorType>())
187     return ttp.getEncoding().dyn_cast_or_null<SparseTensorEncodingAttr>();
188   return nullptr;
189 }
190 
191 //===----------------------------------------------------------------------===//
192 // TensorDialect Operations.
193 //===----------------------------------------------------------------------===//
194 
195 static LogicalResult isInBounds(Value dim, Value tensor) {
196   IntegerAttr constantAttr;
197   if (matchPattern(dim, m_Constant(&constantAttr))) {
198     unsigned d = constantAttr.getInt();
199     if (d >= tensor.getType().cast<RankedTensorType>().getRank())
200       return failure();
201   }
202   return success(); // in bounds, or symbolic
203 }
204 
205 static LogicalResult isMatchingWidth(Value result, unsigned width) {
206   Type etp = result.getType().cast<MemRefType>().getElementType();
207   if ((width == 0 && etp.isIndex()) || (width > 0 && etp.isInteger(width)))
208     return success();
209   return failure();
210 }
211 
212 LogicalResult NewOp::verify() {
213   if (!getSparseTensorEncoding(result().getType()))
214     return emitError("expected a sparse tensor result");
215   return success();
216 }
217 
218 LogicalResult InitOp::verify() {
219   if (!getSparseTensorEncoding(result().getType()))
220     return emitError("expected a sparse tensor result");
221   RankedTensorType ttp = getType().cast<RankedTensorType>();
222   unsigned rank = ttp.getRank();
223   if (rank != sizes().size())
224     return emitError("unexpected mismatch between tensor rank and sizes: ")
225            << rank << " vs. " << sizes().size();
226   auto shape = ttp.getShape();
227   for (unsigned i = 0; i < rank; i++) {
228     if (shape[i] == ShapedType::kDynamicSize)
229       continue;
230     IntegerAttr constantAttr;
231     if (!matchPattern(sizes()[i], m_Constant(&constantAttr)) ||
232         constantAttr.getInt() != shape[i]) {
233       return emitError("unexpected mismatch with static dimension size ")
234              << shape[i];
235     }
236   }
237   return success();
238 }
239 
240 LogicalResult ConvertOp::verify() {
241   if (auto tp1 = source().getType().dyn_cast<RankedTensorType>()) {
242     if (auto tp2 = dest().getType().dyn_cast<RankedTensorType>()) {
243       if (tp1.getRank() != tp2.getRank())
244         return emitError("unexpected conversion mismatch in rank");
245       auto shape1 = tp1.getShape();
246       auto shape2 = tp2.getShape();
247       // Accept size matches between the source and the destination type
248       // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or
249       // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10).
250       for (unsigned d = 0, rank = tp1.getRank(); d < rank; d++)
251         if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamicSize)
252           return emitError("unexpected conversion mismatch in dimension ") << d;
253       return success();
254     }
255   }
256   return emitError("unexpected type in convert");
257 }
258 
259 OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) {
260   if (getType() == source().getType())
261     return source();
262   return {};
263 }
264 
265 LogicalResult ToPointersOp::verify() {
266   if (auto e = getSparseTensorEncoding(tensor().getType())) {
267     if (failed(isInBounds(dim(), tensor())))
268       return emitError("requested pointers dimension out of bounds");
269     if (failed(isMatchingWidth(result(), e.getPointerBitWidth())))
270       return emitError("unexpected type for pointers");
271     return success();
272   }
273   return emitError("expected a sparse tensor to get pointers");
274 }
275 
276 LogicalResult ToIndicesOp::verify() {
277   if (auto e = getSparseTensorEncoding(tensor().getType())) {
278     if (failed(isInBounds(dim(), tensor())))
279       return emitError("requested indices dimension out of bounds");
280     if (failed(isMatchingWidth(result(), e.getIndexBitWidth())))
281       return emitError("unexpected type for indices");
282     return success();
283   }
284   return emitError("expected a sparse tensor to get indices");
285 }
286 
287 LogicalResult ToValuesOp::verify() {
288   if (!getSparseTensorEncoding(tensor().getType()))
289     return emitError("expected a sparse tensor to get values");
290   RankedTensorType ttp = tensor().getType().cast<RankedTensorType>();
291   MemRefType mtp = result().getType().cast<MemRefType>();
292   if (ttp.getElementType() != mtp.getElementType())
293     return emitError("unexpected mismatch in element types");
294   return success();
295 }
296 
297 //===----------------------------------------------------------------------===//
298 // TensorDialect Management Operations.
299 //===----------------------------------------------------------------------===//
300 
301 LogicalResult LexInsertOp::verify() {
302   if (!getSparseTensorEncoding(tensor().getType()))
303     return emitError("expected a sparse tensor for insertion");
304   return success();
305 }
306 
307 LogicalResult ExpandOp::verify() {
308   if (!getSparseTensorEncoding(tensor().getType()))
309     return emitError("expected a sparse tensor for expansion");
310   return success();
311 }
312 
313 LogicalResult CompressOp::verify() {
314   if (!getSparseTensorEncoding(tensor().getType()))
315     return emitError("expected a sparse tensor for compression");
316   return success();
317 }
318 
319 LogicalResult LoadOp::verify() {
320   if (!getSparseTensorEncoding(tensor().getType()))
321     return emitError("expected a sparse tensor to materialize");
322   return success();
323 }
324 
325 LogicalResult ReleaseOp::verify() {
326   if (!getSparseTensorEncoding(tensor().getType()))
327     return emitError("expected a sparse tensor to release");
328   return success();
329 }
330 
331 LogicalResult OutOp::verify() {
332   if (!getSparseTensorEncoding(tensor().getType()))
333     return emitError("expected a sparse tensor for output");
334   return success();
335 }
336 
337 //===----------------------------------------------------------------------===//
338 // TensorDialect Methods.
339 //===----------------------------------------------------------------------===//
340 
void SparseTensorDialect::initialize() {
  // Register the ODS-generated attribute and operation lists with the
  // dialect; the .cpp.inc files expand to the comma-separated class lists.
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
      >();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
      >();
}
351 
352 #define GET_OP_CLASSES
353 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
354