1 //===- SparseTensorDialect.cpp - Sparse tensor dialect implementation -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
10 
11 #include "mlir/IR/Builders.h"
12 #include "mlir/IR/DialectImplementation.h"
13 #include "mlir/IR/Matchers.h"
14 #include "mlir/IR/OpImplementation.h"
15 #include "llvm/ADT/TypeSwitch.h"
16 
17 using namespace mlir;
18 using namespace mlir::sparse_tensor;
19 
20 //===----------------------------------------------------------------------===//
21 // TensorDialect Attribute Methods.
22 //===----------------------------------------------------------------------===//
23 
24 #define GET_ATTRDEF_CLASSES
25 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
26 
/// Returns true iff the given bit width is one of the supported overhead
/// storage widths: 0 (meaning the native `index` type), 8, 16, 32, or 64.
static bool acceptBitWidth(unsigned bitWidth) {
  return bitWidth == 0 || bitWidth == 8 || bitWidth == 16 || bitWidth == 32 ||
         bitWidth == 64;
}
39 
/// Parses a sparse tensor encoding attribute of the textual form
/// `#sparse_tensor.encoding<{ ... }>`. The dictionary may contain the keys
/// "dimLevelType" (array of "dense"/"compressed"/"singleton" strings),
/// "dimOrdering" (affine map), "pointerBitWidth", and "indexBitWidth"
/// (integers). Returns a null attribute after emitting a diagnostic on any
/// malformed or unknown entry.
Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
  if (failed(parser.parseLess()))
    return {};
  // Parse the data as a dictionary.
  DictionaryAttr dict;
  if (failed(parser.parseAttribute(dict)))
    return {};
  if (failed(parser.parseGreater()))
    return {};
  // Process the data from the parsed dictionary value into struct-like data.
  SmallVector<SparseTensorEncodingAttr::DimLevelType, 4> dlt;
  AffineMap map = {};  // empty map means no explicit dimension ordering
  unsigned ptr = 0;    // 0 selects the native index width (see acceptBitWidth)
  unsigned ind = 0;    // likewise for the index overhead width
  for (const NamedAttribute &attr : dict) {
    if (attr.getName() == "dimLevelType") {
      // One string per dimension, mapped onto the DimLevelType enum.
      auto arrayAttr = attr.getValue().dyn_cast<ArrayAttr>();
      if (!arrayAttr) {
        parser.emitError(parser.getNameLoc(),
                         "expected an array for dimension level types");
        return {};
      }
      for (auto i : arrayAttr) {
        auto strAttr = i.dyn_cast<StringAttr>();
        if (!strAttr) {
          parser.emitError(parser.getNameLoc(),
                           "expected a string value in dimension level types");
          return {};
        }
        auto strVal = strAttr.getValue();
        if (strVal == "dense") {
          dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Dense);
        } else if (strVal == "compressed") {
          dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Compressed);
        } else if (strVal == "singleton") {
          dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Singleton);
        } else {
          parser.emitError(parser.getNameLoc(),
                           "unexpected dimension level type: ")
              << strVal;
          return {};
        }
      }
    } else if (attr.getName() == "dimOrdering") {
      auto affineAttr = attr.getValue().dyn_cast<AffineMapAttr>();
      if (!affineAttr) {
        parser.emitError(parser.getNameLoc(),
                         "expected an affine map for dimension ordering");
        return {};
      }
      map = affineAttr.getValue();
    } else if (attr.getName() == "pointerBitWidth") {
      auto intAttr = attr.getValue().dyn_cast<IntegerAttr>();
      if (!intAttr) {
        parser.emitError(parser.getNameLoc(),
                         "expected an integral pointer bitwidth");
        return {};
      }
      ptr = intAttr.getInt();
    } else if (attr.getName() == "indexBitWidth") {
      auto intAttr = attr.getValue().dyn_cast<IntegerAttr>();
      if (!intAttr) {
        parser.emitError(parser.getNameLoc(),
                         "expected an integral index bitwidth");
        return {};
      }
      ind = intAttr.getInt();
    } else {
      // Unknown keys are rejected rather than ignored.
      parser.emitError(parser.getNameLoc(), "unexpected key: ")
          << attr.getName().strref();
      return {};
    }
  }
  // Construct struct-like storage for attribute; getChecked runs the
  // attribute's verify() hook on the collected values.
  return parser.getChecked<SparseTensorEncodingAttr>(parser.getContext(), dlt,
                                                     map, ptr, ind);
}
117 
118 void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
119   // Print the struct-like storage in dictionary fashion.
120   printer << "<{ dimLevelType = [ ";
121   for (unsigned i = 0, e = getDimLevelType().size(); i < e; i++) {
122     switch (getDimLevelType()[i]) {
123     case DimLevelType::Dense:
124       printer << "\"dense\"";
125       break;
126     case DimLevelType::Compressed:
127       printer << "\"compressed\"";
128       break;
129     case DimLevelType::Singleton:
130       printer << "\"singleton\"";
131       break;
132     }
133     if (i != e - 1)
134       printer << ", ";
135   }
136   printer << " ]";
137   if (getDimOrdering())
138     printer << ", dimOrdering = affine_map<" << getDimOrdering() << ">";
139   printer << ", pointerBitWidth = " << getPointerBitWidth()
140           << ", indexBitWidth = " << getIndexBitWidth() << " }>";
141 }
142 
143 LogicalResult SparseTensorEncodingAttr::verify(
144     function_ref<InFlightDiagnostic()> emitError,
145     ArrayRef<DimLevelType> dimLevelType, AffineMap dimOrdering,
146     unsigned pointerBitWidth, unsigned indexBitWidth) {
147   if (!acceptBitWidth(pointerBitWidth))
148     return emitError() << "unexpected pointer bitwidth: " << pointerBitWidth;
149   if (!acceptBitWidth(indexBitWidth))
150     return emitError() << "unexpected index bitwidth: " << indexBitWidth;
151   if (dimOrdering) {
152     if (!dimOrdering.isPermutation())
153       return emitError()
154              << "expected a permutation affine map for dimension ordering";
155     if (dimOrdering.getNumResults() != dimLevelType.size())
156       return emitError() << "unexpected mismatch in ordering and dimension "
157                             "level types size";
158   }
159   return success();
160 }
161 
162 LogicalResult SparseTensorEncodingAttr::verifyEncoding(
163     ArrayRef<int64_t> shape, Type elementType,
164     function_ref<InFlightDiagnostic()> emitError) const {
165   // Check structural integrity.
166   if (failed(verify(emitError, getDimLevelType(), getDimOrdering(),
167                     getPointerBitWidth(), getIndexBitWidth())))
168     return failure();
169   // Check integrity with tensor type specifics. Dimension ordering is optional,
170   // but we always should have dimension level types for the full rank.
171   unsigned size = shape.size();
172   if (size == 0)
173     return emitError() << "expected non-scalar sparse tensor";
174   if (getDimOrdering() && getDimOrdering().getNumResults() != size)
175     return emitError() << "expected an affine map of size " << size
176                        << " for dimension ordering";
177   if (getDimLevelType().size() != size)
178     return emitError() << "expected an array of size " << size
179                        << " for dimension level types";
180   return success();
181 }
182 
183 SparseTensorEncodingAttr
184 mlir::sparse_tensor::getSparseTensorEncoding(Type type) {
185   if (auto ttp = type.dyn_cast<RankedTensorType>())
186     return ttp.getEncoding().dyn_cast_or_null<SparseTensorEncodingAttr>();
187   return nullptr;
188 }
189 
190 //===----------------------------------------------------------------------===//
191 // TensorDialect Operations.
192 //===----------------------------------------------------------------------===//
193 
194 static LogicalResult isInBounds(Value dim, Value tensor) {
195   IntegerAttr constantAttr;
196   if (matchPattern(dim, m_Constant(&constantAttr))) {
197     unsigned d = constantAttr.getInt();
198     if (d >= tensor.getType().cast<RankedTensorType>().getRank())
199       return failure();
200   }
201   return success(); // in bounds, or symbolic
202 }
203 
204 static LogicalResult isMatchingWidth(Value result, unsigned width) {
205   Type etp = result.getType().cast<MemRefType>().getElementType();
206   if ((width == 0 && etp.isIndex()) || (width > 0 && etp.isInteger(width)))
207     return success();
208   return failure();
209 }
210 
211 LogicalResult NewOp::verify() {
212   if (!getSparseTensorEncoding(result().getType()))
213     return emitError("expected a sparse tensor result");
214   return success();
215 }
216 
217 LogicalResult InitOp::verify() {
218   if (!getSparseTensorEncoding(result().getType()))
219     return emitError("expected a sparse tensor result");
220   RankedTensorType ttp = getType().cast<RankedTensorType>();
221   unsigned rank = ttp.getRank();
222   if (rank != sizes().size())
223     return emitError("unexpected mismatch between tensor rank and sizes: ")
224            << rank << " vs. " << sizes().size();
225   auto shape = ttp.getShape();
226   for (unsigned i = 0; i < rank; i++) {
227     if (shape[i] == ShapedType::kDynamicSize)
228       continue;
229     IntegerAttr constantAttr;
230     if (!matchPattern(sizes()[i], m_Constant(&constantAttr)) ||
231         constantAttr.getInt() != shape[i]) {
232       return emitError("unexpected mismatch with static dimension size ")
233              << shape[i];
234     }
235   }
236   return success();
237 }
238 
239 LogicalResult ConvertOp::verify() {
240   if (auto tp1 = source().getType().dyn_cast<RankedTensorType>()) {
241     if (auto tp2 = dest().getType().dyn_cast<RankedTensorType>()) {
242       if (tp1.getRank() != tp2.getRank())
243         return emitError("unexpected conversion mismatch in rank");
244       auto shape1 = tp1.getShape();
245       auto shape2 = tp2.getShape();
246       // Accept size matches between the source and the destination type
247       // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or
248       // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10).
249       for (unsigned d = 0, rank = tp1.getRank(); d < rank; d++)
250         if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamicSize)
251           return emitError("unexpected conversion mismatch in dimension ") << d;
252       return success();
253     }
254   }
255   return emitError("unexpected type in convert");
256 }
257 
258 OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) {
259   if (getType() == source().getType())
260     return source();
261   return {};
262 }
263 
264 LogicalResult ToPointersOp::verify() {
265   if (auto e = getSparseTensorEncoding(tensor().getType())) {
266     if (failed(isInBounds(dim(), tensor())))
267       return emitError("requested pointers dimension out of bounds");
268     if (failed(isMatchingWidth(result(), e.getPointerBitWidth())))
269       return emitError("unexpected type for pointers");
270     return success();
271   }
272   return emitError("expected a sparse tensor to get pointers");
273 }
274 
275 LogicalResult ToIndicesOp::verify() {
276   if (auto e = getSparseTensorEncoding(tensor().getType())) {
277     if (failed(isInBounds(dim(), tensor())))
278       return emitError("requested indices dimension out of bounds");
279     if (failed(isMatchingWidth(result(), e.getIndexBitWidth())))
280       return emitError("unexpected type for indices");
281     return success();
282   }
283   return emitError("expected a sparse tensor to get indices");
284 }
285 
286 LogicalResult ToValuesOp::verify() {
287   if (!getSparseTensorEncoding(tensor().getType()))
288     return emitError("expected a sparse tensor to get values");
289   RankedTensorType ttp = tensor().getType().cast<RankedTensorType>();
290   MemRefType mtp = result().getType().cast<MemRefType>();
291   if (ttp.getElementType() != mtp.getElementType())
292     return emitError("unexpected mismatch in element types");
293   return success();
294 }
295 
296 //===----------------------------------------------------------------------===//
297 // TensorDialect Management Operations.
298 //===----------------------------------------------------------------------===//
299 
300 LogicalResult LexInsertOp::verify() {
301   if (!getSparseTensorEncoding(tensor().getType()))
302     return emitError("expected a sparse tensor for insertion");
303   return success();
304 }
305 
306 LogicalResult ExpandOp::verify() {
307   if (!getSparseTensorEncoding(tensor().getType()))
308     return emitError("expected a sparse tensor for expansion");
309   return success();
310 }
311 
312 LogicalResult CompressOp::verify() {
313   if (!getSparseTensorEncoding(tensor().getType()))
314     return emitError("expected a sparse tensor for compression");
315   return success();
316 }
317 
318 LogicalResult LoadOp::verify() {
319   if (!getSparseTensorEncoding(tensor().getType()))
320     return emitError("expected a sparse tensor to materialize");
321   return success();
322 }
323 
324 LogicalResult ReleaseOp::verify() {
325   if (!getSparseTensorEncoding(tensor().getType()))
326     return emitError("expected a sparse tensor to release");
327   return success();
328 }
329 
330 LogicalResult OutOp::verify() {
331   if (!getSparseTensorEncoding(tensor().getType()))
332     return emitError("expected a sparse tensor for output");
333   return success();
334 }
335 
336 //===----------------------------------------------------------------------===//
337 // TensorDialect Linalg.Generic Operations.
338 //===----------------------------------------------------------------------===//
339 
340 template <class T>
341 static LogicalResult verifyNumBlockArgs(T *op, Region &region,
342                                         const char *regionName,
343                                         TypeRange inputTypes, Type outputType) {
344   unsigned numArgs = region.getNumArguments();
345   unsigned expectedNum = inputTypes.size();
346   if (numArgs != expectedNum)
347     return op->emitError() << regionName << " region must have exactly "
348                            << expectedNum << " arguments";
349 
350   for (unsigned i = 0; i < numArgs; i++) {
351     Type typ = region.getArgument(i).getType();
352     if (typ != inputTypes[i])
353       return op->emitError() << regionName << " region argument " << (i + 1)
354                              << " type mismatch";
355   }
356   Operation *term = region.front().getTerminator();
357   YieldOp yield = dyn_cast<YieldOp>(term);
358   if (!yield)
359     return op->emitError() << regionName
360                            << " region must end with sparse_tensor.yield";
361   if (yield.getOperand().getType() != outputType)
362     return op->emitError() << regionName << " region yield type mismatch";
363 
364   return success();
365 }
366 
367 LogicalResult BinaryOp::verify() {
368   NamedAttrList attrs = (*this)->getAttrs();
369   Type leftType = x().getType();
370   Type rightType = y().getType();
371   Type outputType = output().getType();
372   Region &overlap = overlapRegion();
373   Region &left = leftRegion();
374   Region &right = rightRegion();
375 
376   // Check correct number of block arguments and return type for each
377   // non-empty region.
378   LogicalResult regionResult = success();
379   if (!overlap.empty()) {
380     regionResult = verifyNumBlockArgs(
381         this, overlap, "overlap", TypeRange{leftType, rightType}, outputType);
382     if (failed(regionResult))
383       return regionResult;
384   }
385   if (!left.empty()) {
386     regionResult =
387         verifyNumBlockArgs(this, left, "left", TypeRange{leftType}, outputType);
388     if (failed(regionResult))
389       return regionResult;
390   } else if (left_identity()) {
391     if (leftType != outputType)
392       return emitError("left=identity requires first argument to have the same "
393                        "type as the output");
394   }
395   if (!right.empty()) {
396     regionResult = verifyNumBlockArgs(this, right, "right",
397                                       TypeRange{rightType}, outputType);
398     if (failed(regionResult))
399       return regionResult;
400   } else if (right_identity()) {
401     if (rightType != outputType)
402       return emitError("right=identity requires second argument to have the "
403                        "same type as the output");
404   }
405 
406   return success();
407 }
408 
409 LogicalResult UnaryOp::verify() {
410   Type inputType = x().getType();
411   Type outputType = output().getType();
412   LogicalResult regionResult = success();
413 
414   // Check correct number of block arguments and return type for each
415   // non-empty region.
416   Region &present = presentRegion();
417   if (!present.empty()) {
418     regionResult = verifyNumBlockArgs(this, present, "present",
419                                       TypeRange{inputType}, outputType);
420     if (failed(regionResult))
421       return regionResult;
422   }
423   Region &absent = absentRegion();
424   if (!absent.empty()) {
425     regionResult =
426         verifyNumBlockArgs(this, absent, "absent", TypeRange{}, outputType);
427     if (failed(regionResult))
428       return regionResult;
429   }
430 
431   return success();
432 }
433 
434 LogicalResult YieldOp::verify() {
435   // Check for compatible parent.
436   auto *parentOp = (*this)->getParentOp();
437   if (auto binaryOp = dyn_cast<BinaryOp>(parentOp))
438     return success();
439   if (auto unaryOp = dyn_cast<UnaryOp>(parentOp))
440     return success();
441 
442   return emitOpError("expected parent op to be sparse_tensor binary or unary");
443 }
444 
445 //===----------------------------------------------------------------------===//
446 // TensorDialect Methods.
447 //===----------------------------------------------------------------------===//
448 
/// Registers all attributes and operations of the sparse tensor dialect,
/// pulling the lists from the TableGen-generated .cpp.inc files.
void SparseTensorDialect::initialize() {
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
      >();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
      >();
}
459 
460 #define GET_OP_CLASSES
461 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
462 
463 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc"
464