//===- SparseTensorDialect.cpp - Sparse tensor dialect implementation -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"

#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

//===----------------------------------------------------------------------===//
// SparseTensorDialect Attribute Methods.
//===----------------------------------------------------------------------===//

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"

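// Returns true iff the given bit width is one of the supported pointer/index
// bit widths, where a width of 0 requests the native `index` width (see
// isMatchingWidth below).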
static bool acceptBitWidth(unsigned bitWidth) {
  switch (bitWidth) {
  case 0:
  case 8:
  case 16:
  case 32:
  case 64:
    return true;
  default:
    return false;
  }
}

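// An illustrative sketch of the textual form handled by parse() below and
// emitted by print(); the concrete values are chosen for exposition only and
// are not taken from a specific test:
//
//   #sparse_tensor.encoding<{
//     dimLevelType = [ "dense", "compressed" ],
//     dimOrdering = affine_map<(i, j) -> (j, i)>,
//     pointerBitWidth = 64,
//     indexBitWidth = 64
//   }>
//
// The dimOrdering entry is optional, and omitted bit widths default to 0
// (the native index width).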
Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
  if (failed(parser.parseLess()))
    return {};
  // Parse the data as a dictionary.
  DictionaryAttr dict;
  if (failed(parser.parseAttribute(dict)))
    return {};
  if (failed(parser.parseGreater()))
    return {};
  // Process the data from the parsed dictionary value into struct-like data.
  SmallVector<SparseTensorEncodingAttr::DimLevelType, 4> dlt;
  AffineMap map = {};
  unsigned ptr = 0;
  unsigned ind = 0;
  for (const NamedAttribute &attr : dict) {
    if (attr.getName() == "dimLevelType") {
      auto arrayAttr = attr.getValue().dyn_cast<ArrayAttr>();
      if (!arrayAttr) {
        parser.emitError(parser.getNameLoc(),
                         "expected an array for dimension level types");
        return {};
      }
      for (auto i : arrayAttr) {
        auto strAttr = i.dyn_cast<StringAttr>();
        if (!strAttr) {
          parser.emitError(parser.getNameLoc(),
                           "expected a string value in dimension level types");
          return {};
        }
        auto strVal = strAttr.getValue();
        if (strVal == "dense") {
          dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Dense);
        } else if (strVal == "compressed") {
          dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Compressed);
        } else if (strVal == "singleton") {
          dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Singleton);
        } else {
          parser.emitError(parser.getNameLoc(),
                           "unexpected dimension level type: ")
              << strVal;
          return {};
        }
      }
    } else if (attr.getName() == "dimOrdering") {
      auto affineAttr = attr.getValue().dyn_cast<AffineMapAttr>();
      if (!affineAttr) {
        parser.emitError(parser.getNameLoc(),
                         "expected an affine map for dimension ordering");
        return {};
      }
      map = affineAttr.getValue();
    } else if (attr.getName() == "pointerBitWidth") {
      auto intAttr = attr.getValue().dyn_cast<IntegerAttr>();
      if (!intAttr) {
        parser.emitError(parser.getNameLoc(),
                         "expected an integral pointer bitwidth");
        return {};
      }
      ptr = intAttr.getInt();
    } else if (attr.getName() == "indexBitWidth") {
      auto intAttr = attr.getValue().dyn_cast<IntegerAttr>();
      if (!intAttr) {
        parser.emitError(parser.getNameLoc(),
                         "expected an integral index bitwidth");
        return {};
      }
      ind = intAttr.getInt();
    } else {
      parser.emitError(parser.getNameLoc(), "unexpected key: ")
          << attr.getName().strref();
      return {};
    }
  }
  // Construct struct-like storage for attribute.
  return parser.getChecked<SparseTensorEncodingAttr>(parser.getContext(), dlt,
                                                     map, ptr, ind);
}

void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
  // Print the struct-like storage in dictionary fashion.
  printer << "<{ dimLevelType = [ ";
  for (unsigned i = 0, e = getDimLevelType().size(); i < e; i++) {
    switch (getDimLevelType()[i]) {
    case DimLevelType::Dense:
      printer << "\"dense\"";
      break;
    case DimLevelType::Compressed:
      printer << "\"compressed\"";
      break;
    case DimLevelType::Singleton:
      printer << "\"singleton\"";
      break;
    }
    if (i != e - 1)
      printer << ", ";
  }
  printer << " ]";
  if (getDimOrdering())
    printer << ", dimOrdering = affine_map<" << getDimOrdering() << ">";
  printer << ", pointerBitWidth = " << getPointerBitWidth()
          << ", indexBitWidth = " << getIndexBitWidth() << " }>";
}

LogicalResult SparseTensorEncodingAttr::verify(
    function_ref<InFlightDiagnostic()> emitError,
    ArrayRef<DimLevelType> dimLevelType, AffineMap dimOrdering,
    unsigned pointerBitWidth, unsigned indexBitWidth) {
  if (!acceptBitWidth(pointerBitWidth))
    return emitError() << "unexpected pointer bitwidth: " << pointerBitWidth;
  if (!acceptBitWidth(indexBitWidth))
    return emitError() << "unexpected index bitwidth: " << indexBitWidth;
  if (dimOrdering) {
    if (!dimOrdering.isPermutation())
      return emitError()
             << "expected a permutation affine map for dimension ordering";
    if (dimOrdering.getNumResults() != dimLevelType.size())
      return emitError() << "unexpected mismatch in ordering and dimension "
                            "level types size";
  }
  return success();
}

LogicalResult SparseTensorEncodingAttr::verifyEncoding(
    ArrayRef<int64_t> shape, Type elementType,
    function_ref<InFlightDiagnostic()> emitError) const {
  // Check structural integrity.
  if (failed(verify(emitError, getDimLevelType(), getDimOrdering(),
                    getPointerBitWidth(), getIndexBitWidth())))
    return failure();
  // Check integrity with tensor type specifics. Dimension ordering is optional,
  // but we should always have dimension level types for the full rank.
  unsigned size = shape.size();
  if (size == 0)
    return emitError() << "expected non-scalar sparse tensor";
  if (getDimOrdering() && getDimOrdering().getNumResults() != size)
    return emitError() << "expected an affine map of size " << size
                       << " for dimension ordering";
  if (getDimLevelType().size() != size)
    return emitError() << "expected an array of size " << size
                       << " for dimension level types";
  return success();
}

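// Returns the sparse tensor encoding attribute of the given type, or null if
// the type is not a ranked tensor or carries no such encoding.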
SparseTensorEncodingAttr
mlir::sparse_tensor::getSparseTensorEncoding(Type type) {
  if (auto ttp = type.dyn_cast<RankedTensorType>())
    return ttp.getEncoding().dyn_cast_or_null<SparseTensorEncodingAttr>();
  return nullptr;
}

//===----------------------------------------------------------------------===//
// SparseTensorDialect Operations.
//===----------------------------------------------------------------------===//

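// Returns success when the requested dimension is either not a compile-time
// constant (symbolic) or a constant that lies within the rank of the given
// ranked tensor.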
static LogicalResult isInBounds(Value dim, Value tensor) {
  IntegerAttr constantAttr;
  if (matchPattern(dim, m_Constant(&constantAttr))) {
    unsigned d = constantAttr.getInt();
    if (d >= tensor.getType().cast<RankedTensorType>().getRank())
      return failure();
  }
  return success(); // in bounds, or symbolic
}

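// Returns success when the memref result has the element type implied by the
// requested bit width: `index` for width 0, and a matching integer type for a
// nonzero width.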
static LogicalResult isMatchingWidth(Value result, unsigned width) {
  Type etp = result.getType().cast<MemRefType>().getElementType();
  if ((width == 0 && etp.isIndex()) || (width > 0 && etp.isInteger(width)))
    return success();
  return failure();
}

LogicalResult NewOp::verify() {
  if (!getSparseTensorEncoding(result().getType()))
    return emitError("expected a sparse tensor result");
  return success();
}

LogicalResult InitOp::verify() {
  if (!getSparseTensorEncoding(result().getType()))
    return emitError("expected a sparse tensor result");
  RankedTensorType ttp = getType().cast<RankedTensorType>();
  unsigned rank = ttp.getRank();
  if (rank != sizes().size())
    return emitError("unexpected mismatch between tensor rank and sizes: ")
           << rank << " vs. " << sizes().size();
  auto shape = ttp.getShape();
  for (unsigned i = 0; i < rank; i++) {
    if (shape[i] == ShapedType::kDynamicSize)
      continue;
    IntegerAttr constantAttr;
    if (!matchPattern(sizes()[i], m_Constant(&constantAttr)) ||
        constantAttr.getInt() != shape[i]) {
      return emitError("unexpected mismatch with static dimension size ")
             << shape[i];
    }
  }
  return success();
}

LogicalResult ConvertOp::verify() {
  if (auto tp1 = source().getType().dyn_cast<RankedTensorType>()) {
    if (auto tp2 = dest().getType().dyn_cast<RankedTensorType>()) {
      if (tp1.getRank() != tp2.getRank())
        return emitError("unexpected conversion mismatch in rank");
      auto shape1 = tp1.getShape();
      auto shape2 = tp2.getShape();
      // Accept size matches between the source and the destination type
      // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or
      // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10).
      for (unsigned d = 0, rank = tp1.getRank(); d < rank; d++)
        if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamicSize)
          return emitError("unexpected conversion mismatch in dimension ") << d;
      return success();
    }
  }
  return emitError("unexpected type in convert");
}

OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) {
  if (getType() == source().getType())
    return source();
  return {};
}

LogicalResult ToPointersOp::verify() {
  if (auto e = getSparseTensorEncoding(tensor().getType())) {
    if (failed(isInBounds(dim(), tensor())))
      return emitError("requested pointers dimension out of bounds");
    if (failed(isMatchingWidth(result(), e.getPointerBitWidth())))
      return emitError("unexpected type for pointers");
    return success();
  }
  return emitError("expected a sparse tensor to get pointers");
}

LogicalResult ToIndicesOp::verify() {
  if (auto e = getSparseTensorEncoding(tensor().getType())) {
    if (failed(isInBounds(dim(), tensor())))
      return emitError("requested indices dimension out of bounds");
    if (failed(isMatchingWidth(result(), e.getIndexBitWidth())))
      return emitError("unexpected type for indices");
    return success();
  }
  return emitError("expected a sparse tensor to get indices");
}

LogicalResult ToValuesOp::verify() {
  if (!getSparseTensorEncoding(tensor().getType()))
    return emitError("expected a sparse tensor to get values");
  RankedTensorType ttp = tensor().getType().cast<RankedTensorType>();
  MemRefType mtp = result().getType().cast<MemRefType>();
  if (ttp.getElementType() != mtp.getElementType())
    return emitError("unexpected mismatch in element types");
  return success();
}

//===----------------------------------------------------------------------===//
// SparseTensorDialect Management Operations.
//===----------------------------------------------------------------------===//

LogicalResult LexInsertOp::verify() {
  if (!getSparseTensorEncoding(tensor().getType()))
    return emitError("expected a sparse tensor for insertion");
  return success();
}

LogicalResult ExpandOp::verify() {
  if (!getSparseTensorEncoding(tensor().getType()))
    return emitError("expected a sparse tensor for expansion");
  return success();
}

LogicalResult CompressOp::verify() {
  if (!getSparseTensorEncoding(tensor().getType()))
    return emitError("expected a sparse tensor for compression");
  return success();
}

LogicalResult LoadOp::verify() {
  if (!getSparseTensorEncoding(tensor().getType()))
    return emitError("expected a sparse tensor to materialize");
  return success();
}

LogicalResult ReleaseOp::verify() {
  if (!getSparseTensorEncoding(tensor().getType()))
    return emitError("expected a sparse tensor to release");
  return success();
}

LogicalResult OutOp::verify() {
  if (!getSparseTensorEncoding(tensor().getType()))
    return emitError("expected a sparse tensor for output");
  return success();
}

//===----------------------------------------------------------------------===//
// SparseTensorDialect Methods.
//===----------------------------------------------------------------------===//

void SparseTensorDialect::initialize() {
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
      >();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
      >();
}

#define GET_OP_CLASSES
#include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"

#include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc"